Source code for concepts.dsl.dsl_functions

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dsl_functions.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/19/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Data structures representing functions in a DSL.

Most importantly, this file contains the following classes:

- :class:`FunctionType`: the underlying type of a function, including argument types and return types.
- :class:`Function`: the function object, which is a callable object that can be used in expressions. They have names and types.

This file also implements a data structure for overloaded functions: :class:`OverloadedFunctionType`.
Internally, it contains a list of :class:`FunctionType` objects, and it is used to represent overloaded functions.
There are a few argument resolution methods implemented for both :class:`FunctionType` and :class:`OverloadedFunctionType`.
"""

import itertools
import contextlib
import inspect
import re
from typing import TYPE_CHECKING, Any, Union, Sequence, Tuple, List, Dict, Optional, Callable
from dataclasses import dataclass

import jacinle
from jacinle.utils.cache import cached_property
from jacinle.utils.defaults import option_context
from jacinle.utils.printing import indent_text

from concepts.dsl.dsl_types import TypeBase, ObjectType, ValueType, SequenceType, TupleType, UniformSequenceType, ConstantType, UnionType, Variable
from concepts.dsl.dsl_types import FormatContext, get_format_context

if TYPE_CHECKING:
    from concepts.dsl.expression import Expression, FunctionApplicationExpression

# Public API of this module: the names exported by ``from concepts.dsl.dsl_functions import *``.
# Underscore-prefixed helpers defined in this file are intentionally not listed.
__all__ = [
    'FunctionArgumentResolutionError', 'FunctionArgumentResolutionContext', 'get_function_argument_resolution_context',
    'FunctionArgumentUnset', 'AnonymousFunctionArgumentGenerator',
    'FunctionArgumentType', 'FunctionArgumentListType', 'FunctionReturnType', 'FunctionType',
    'OverloadedFunctionResolution', 'OverloadedFunctionAmbiguousResolutions', 'OverloadedFunctionType',
    'FunctionTyping',
    'FunctionOverriddenCallList', 'FunctionDerivedExpressionList', 'FunctionResolvedFromRecord', 'Function',
]


class FunctionArgumentResolutionError(Exception):
    """Signal that the arguments of a function call could not be resolved against its type."""
class FunctionArgumentResolutionContext(option_context(
    '_FunctionArgumentResolutionContext',
    check_missing=True, check_type=True, check_overloaded_ambiguity=True, exc_verbose=True
)):
    """Options that control how function arguments are resolved.

    Attributes:
        check_missing (bool): whether to report arguments that were not provided.
        check_type (bool): whether to type-check the provided arguments.
        check_overloaded_ambiguity (bool): whether an ambiguous overload resolution is reported as an error.
        exc_verbose (bool): whether detailed error messages are produced.
    """

    @contextlib.contextmanager
    def exc(self, exc_type=None, from_=None):
        """Guard a block that builds and raises a detailed resolution error.

        When ``exc_verbose`` is set, the guarded block simply runs, so it can format
        and raise its own exception. Otherwise a message-less exception is raised
        while entering the context, so the guarded block never runs.

        Args:
            exc_type: the exception class raised in non-verbose mode. Defaults to
                :class:`FunctionArgumentResolutionError`.
            from_: optional exception to chain from (only used in non-verbose mode).
        """
        if not self.exc_verbose:
            error_cls = FunctionArgumentResolutionError if exc_type is None else exc_type
            if from_ is None:
                raise error_cls()
            raise error_cls() from from_
        yield
# Module-level accessor bound to ``FunctionArgumentResolutionContext.get_default``; the resolution
# routines in this file call it to obtain the context whose flags they should honour.
get_function_argument_resolution_context: Callable[[], FunctionArgumentResolutionContext] = FunctionArgumentResolutionContext.get_default
"""Get the current function argument resolution context."""

# Unique sentinel object. It is compared by identity (``is``) to tell "argument not provided"
# apart from any real argument value, including ``None``.
FunctionArgumentUnset = object()
"""A placeholder indicating that the argument is not specified."""
class AnonymousFunctionArgumentGenerator(object):
    """Produce fresh names for anonymous function arguments from a numbered template."""

    def __init__(self, template='_t{i:d}'):
        """Initialize the generator.

        Args:
            template: a format string with an integer field ``i``; it is formatted with the running counter.
        """
        self.counter = 0
        self.template = template

    @property
    def nr_generated(self) -> int:
        """The number of names generated so far."""
        return self.counter

    def gen(self, n: Optional[int] = None) -> Union[str, List[str]]:
        """Generate one name, or a list of ``n`` names.

        Args:
            n: when None, a single name is returned; otherwise a list with ``n`` fresh names.

        Returns:
            A single name (``n`` is None) or a list of names.
        """
        if n is not None:
            return [self.gen() for _ in range(n)]

        # The counter is advanced first, so the first generated name uses index 1.
        self.counter += 1
        return self.template.format(i=self.counter)
# 'FunctionType' is written as a string (forward reference) because the class is defined further
# down in this module.
FunctionArgumentType = Union[ObjectType, ValueType, 'FunctionType', SequenceType]
"""Acceptable types for function arguments. See the documentation of `FunctionType` for more details."""

# The three accepted shapes mirror the branches of ``FunctionType.__init__``: a sequence of types,
# a sequence of ``Variable`` objects, or a ``{name: type}`` dict.
FunctionArgumentListType = Union[Sequence[FunctionArgumentType], Sequence[Variable], Dict[str, FunctionArgumentType]]
"""Acceptable types for function argument lists. See the documentation of `FunctionType` for more details."""

FunctionReturnType = Union[ObjectType, ValueType, 'FunctionType', SequenceType]
"""Acceptable types for function return types. See the documentation of `FunctionType` for more details."""
class FunctionType(TypeBase):
    """FunctionType defines the signature of a function."""

    argument_types: Tuple[FunctionArgumentType, ...]
    """The types of the arguments."""

    argument_names: Tuple[str, ...]
    """The names of the arguments."""

    arguments: Tuple[Variable, ...]
    """The argument list composed of `Variable` instances."""

    arguments_dict: Dict[str, FunctionArgumentType]
    """The arguments as a dict, as mappings from argument names to argument types."""

    arguments_name2index: Dict[str, int]
    """The mapping from argument names to argument indices."""

    return_type: FunctionReturnType
    """The return type of the function type."""

    return_name: Optional[Union[str, Tuple[str, ...]]]
    """The name of the return value."""

    is_generator_function: bool
    """Whether the function is a generator function."""

    is_simple_arguments: bool
    """Whether all arguments are not batched-list types."""

    is_singular_return: bool
    """Whether there is only one return value."""

    is_cacheable: bool
    """Whether the function is cacheable."""

    def __init__(
        self,
        arguments: FunctionArgumentListType,
        return_type: FunctionReturnType,
        argument_names: Optional[Sequence[str]] = None,
        return_name: Optional[Union[str, Sequence[str]]] = None,
        is_generator_function: bool = False,
        alias: Optional[str] = None
    ):
        """Initialize the function type. There are four ways to specify the arguments of the function type:

        1. A list of types, in which case the name of each argument is generated from its index, using the format `_{index}`.
        2. A list of types as the `arguments`, and a list of names as the `argument_names`.
        3. A list of variables, in which case the name of each argument is the name of the variable.
        4. A dictionary of {name: type}, in which case the order of the arguments is the order of the keys.

        The return type can be either a single type or a tuple of types (multi-return types).

        Args:
            arguments: The arguments of the function type. When it is a list, the name of the arguments
                will be automatically generated. When it is a dict, the name of the arguments will be the keys of the dict.
            return_type: The return type of the function type.
            argument_names: The names of the arguments.
            return_name: The name of the return value.
            is_generator_function: Whether the function is a generator function.
            alias: The alias name of the function type.

        Raises:
            TypeError: if `arguments` is neither a list/tuple nor a dict.
            ValueError: if the return type is singular but `return_name` is not a string.
        """
        self.arguments = None  # noqa
        self.arguments_dict = None  # noqa

        if isinstance(arguments, (list, tuple)):
            if len(arguments) == 0:
                self.arguments = tuple()
                self.arguments_dict = dict()
                self.argument_names = tuple()
                self.argument_types = tuple()
            elif isinstance(arguments[0], Variable):
                assert argument_names is None, 'Cannot specify both `arguments` and `argument_names`.'
                self.argument_types = tuple(arg.dtype.unwrap_alias() for arg in arguments)
                self.argument_names = tuple(arg.name for arg in arguments)
                # Always store a tuple, as declared by the attribute annotation, even if a list was passed in.
                self.arguments = tuple(arguments)
            else:
                if argument_names is None:
                    argument_names = tuple('_' + str(i) for i in range(len(arguments)))
                else:
                    assert len(arguments) == len(argument_names), 'The length of `arguments` and `argument_names` must be the same.'
                self.argument_types = tuple([arg.unwrap_alias() for arg in arguments])
                self.argument_names = tuple(argument_names)
        elif isinstance(arguments, dict):
            assert argument_names is None, 'Cannot specify both `arguments` and `argument_names`.'
            self.argument_names = tuple(arguments.keys())
            self.argument_types = tuple([arg.unwrap_alias() for arg in arguments.values()])
            self.arguments_dict = arguments.copy()
        else:
            raise TypeError(f'Invalid argument types: {arguments}. Must be a list or a dict.')

        # Fill in whichever of the two derived views was not provided directly by the caller.
        if self.arguments is None:
            self.arguments = tuple(Variable(name, dtype.unwrap_alias()) for name, dtype in zip(self.argument_names, self.argument_types))
        if self.arguments_dict is None:
            self.arguments_dict = {name: dtype for name, dtype in zip(self.argument_names, self.argument_types)}

        if isinstance(return_type, TupleType):
            self.return_type = return_type
            if return_name is None:
                self.return_name = None
            else:
                self.return_name = tuple(return_name)
                assert len(return_type.element_types) == len(self.return_name), 'The length of `return_type` and `return_name` must be the same.'

            # NB(Jiayuan Mao @ 2024/03/10): I commented out this behavior but I'm not sure if it's correct.
            # if len(return_type.element_types) == 1:
            #     self.return_type = self.return_type.element_types[0]
            #     if self.return_name is not None:
            #         self.return_name = self.return_name[0]
        else:
            if return_name is not None and not isinstance(return_name, str):
                raise ValueError(f'Invalid return name: {return_name}. Must be a string when the function has a singular return type')
            self.return_type = return_type
            self.return_name = return_name

        self.is_generator_function = is_generator_function
        self.is_simple_arguments = all(not isinstance(arg, UniformSequenceType) for arg in self.argument_types)
        self.is_singular_return = not isinstance(self.return_type, TupleType)
        self.is_cacheable = self._gen_is_cacheable()

        super().__init__(self._gen_typename(), alias=alias)

    def _gen_typename(self) -> str:
        """Build the type name from the argument list and the return type (wrapped in ``Gen[...]`` for generators)."""
        argument_string = '(' + ', '.join([str(arg) for arg in self.arguments]) + ')'
        if self.is_generator_function:
            return argument_string + ' -> Gen[' + str(self.return_type) + ']'
        return argument_string + ' -> ' + str(self.return_type)

    def _gen_is_cacheable(self):
        """Compute the value stored in :attr:`is_cacheable`. Generator functions are never cacheable."""
        if self.is_generator_function:
            return False
        # NOTE(review): this iterates over `self.arguments` (Variable instances) and tests them against
        # ValueType; it may have been meant to test the argument *types*. Behavior kept as-is -- confirm.
        for arg_def in self.arguments:
            if isinstance(arg_def, ValueType):
                return False
        return True

    @property
    def nr_arguments(self) -> int:
        """Return the number of arguments."""
        return len(self.argument_types)

    @cached_property
    def nr_object_arguments(self) -> int:
        """Return the number of arguments that are ObjectType-ed."""
        return sum(1 for x in self.argument_types if isinstance(x, ObjectType))

    @cached_property
    def nr_value_arguments(self) -> int:
        """Return the number of arguments that are ValueType-ed."""
        return sum(1 for x in self.argument_types if isinstance(x, ValueType))

    @cached_property
    def nr_variable_arguments(self) -> int:
        """Return the number of arguments that are VariableType-ed (i.e., ValueType-ed but not ConstantType-ed)."""
        return sum(1 for x in self.argument_types if isinstance(x, ValueType) and not isinstance(x, ConstantType))

    @cached_property
    def nr_constant_arguments(self) -> int:
        """Return the number of arguments that are ConstantType-ed."""
        return sum(1 for x in self.argument_types if isinstance(x, ConstantType))

    @cached_property
    def arguments_name2index(self):
        """The mapping from argument names to argument indices."""
        assert self.argument_names is not None
        return {name: index for index, name in enumerate(self.argument_names)}

    @classmethod
    def from_annotation(cls, function: Callable, sig: Optional[inspect.Signature] = None) -> Union['FunctionType', 'OverloadedFunctionType']:
        """Create a FunctionType from a function annotation.

        Args:
            function: The function.
            sig: The signature of the function. When None, it is obtained with :func:`inspect.signature`.

        Returns:
            Union[FunctionType, OverloadedFunctionType]: The function type. An overloaded function type is
            returned when at least one argument is annotated with a :class:`UnionType`.

        Raises:
            FunctionArgumentResolutionError: if any argument or the return value has no annotation.
        """
        if sig is None:
            sig = inspect.signature(function)

        argument_types = list()
        argument_names = list()
        for i, (name, param) in enumerate(sig.parameters.items()):
            # A leading `self` / `cls` parameter marks an instancemethod / classmethod and is skipped.
            if i == 0 and name in ('self', 'cls'):
                continue
            argument_names.append(name)
            argument_types.append(param.annotation)
        return_type = sig.return_annotation

        # `inspect.Parameter.empty` and `inspect.Signature.empty` are the public names of the "no annotation" marker.
        has_missing_argument_annotation = any(t is inspect.Parameter.empty for t in argument_types)
        if has_missing_argument_annotation or return_type is inspect.Signature.empty:
            raise FunctionArgumentResolutionError(f'Incomplete argument and return type annotation for {function}.')

        function_type = cls(argument_types, return_type, argument_names=argument_names)
        if any(isinstance(arg_type, UnionType) for arg_type in function_type.argument_types):
            return OverloadedFunctionType.from_function_type_with_union_arguments(function_type)
        return function_type

    def resolve_args(self, *args: Any, **kwargs: Any) -> List[Any]:
        """Resolve the arguments to the function type.

        Keyword arguments are matched against the argument names of this function type. In addition,
        the key `#{index}` always refers to the argument at that position (such keys can only be passed
        through ``**{...}`` unpacking). When the function type was created without explicit names,
        the generated names `_{index}` can be used as ordinary keywords.

        The checks performed (missing arguments, types) and the verbosity of the errors are controlled by
        the current :class:`FunctionArgumentResolutionContext`.

        Args:
            *args: The positional arguments.
            **kwargs: The keyword arguments.

        Returns:
            A list of argument values, one per declared argument. Entries that were not provided are
            :data:`FunctionArgumentUnset` (only possible when the missing-argument check is disabled).

        Raises:
            FunctionArgumentResolutionError: if the arguments do not match the function type.
        """
        resolution_context = get_function_argument_resolution_context()

        # Construct a mapping from the name of arguments to their indices.
        name2index = {f'#{i}': i for i in range(self.nr_arguments)}
        name2index.update(self.arguments_name2index)

        arguments = [FunctionArgumentUnset for _ in range(self.nr_arguments)]

        if len(args) + len(kwargs) > self.nr_arguments:
            with resolution_context.exc():
                raise FunctionArgumentResolutionError(f'Function {self} takes {len(self.argument_types)} arguments, got {len(args) + len(kwargs)}.')

        for i, value in enumerate(args):
            arguments[i] = value

        for k, v in kwargs.items():
            if k not in name2index:
                with resolution_context.exc():
                    raise FunctionArgumentResolutionError(f'Got unknown keyword argument: {k} when invoking function {self}.')
            i = name2index[k]
            if arguments[i] is not FunctionArgumentUnset:
                with resolution_context.exc():
                    raise FunctionArgumentResolutionError(f'Got duplicated argument for keyword argument: {k} when invoking function {self}.')
            arguments[i] = v

        if resolution_context.check_missing:
            for i in range(self.nr_arguments):
                if arguments[i] is FunctionArgumentUnset:
                    with resolution_context.exc():
                        raise FunctionArgumentResolutionError(f'Missing argument {self.argument_names[i]} when invoking function {self}.')

        if resolution_context.check_type:
            # Imported lazily, presumably to avoid a circular import (the module is only imported under TYPE_CHECKING at the top).
            from concepts.dsl.expression import get_types
            argument_types = get_types(arguments)
            for i in range(self.nr_arguments):
                if argument_types[i] is not FunctionArgumentUnset and not argument_types[i].downcast_compatible(self.argument_types[i]):
                    with resolution_context.exc():
                        raise FunctionArgumentResolutionError(f'Typecheck failed for argument {self.argument_names[i]} while invoking the function {self}.\nInvoked with types: {argument_types}.')

        return arguments
@dataclass
class OverloadedFunctionResolution(object):
    """The data structure for storing the result of resolving an overloaded function.

    Besides attribute access, the record can be unpacked and indexed like the tuple
    ``(type_index, ftype, arguments)``, which is how several call sites in this module consume it.
    """

    type_index: int
    """The index of the function type that matches the expected signature."""

    ftype: 'FunctionType'
    """The function type that matches the expected signature."""

    arguments: List[Any]
    """The resolved arguments."""

    def _as_tuple(self) -> Tuple[int, 'FunctionType', List[Any]]:
        """Return the fields as a ``(type_index, ftype, arguments)`` tuple (no copies are made)."""
        return self.type_index, self.ftype, self.arguments

    def __iter__(self):
        """Iterate over ``(type_index, ftype, arguments)``, enabling ``i, ftype, args = resolution``."""
        return iter(self._as_tuple())

    def __getitem__(self, index):
        """Index into ``(type_index, ftype, arguments)``, enabling ``resolution[1]`` style access."""
        return self._as_tuple()[index]
class OverloadedFunctionAmbiguousResolutions(list, List[OverloadedFunctionResolution]):
    """A list of :class:`OverloadedFunctionResolution` candidates, used when more than one overload matches."""
class OverloadedFunctionType(TypeBase):
    """The type of an overloaded function: an ordered collection of :class:`FunctionType` candidates."""

    types: Tuple[FunctionType, ...]
    """The candidate function types, with nested overloaded types flattened."""

    def __init__(
        self,
        types: Sequence[Union['OverloadedFunctionType', FunctionType]],
        alias: Optional[str] = None
    ):
        """Initialize the overloaded function type.

        Args:
            types: the candidate types. Nested :class:`OverloadedFunctionType` entries are flattened in place.
            alias: the alias name of the type.
        """
        types_flatten: List[FunctionType] = list()
        for ftype in types:
            if isinstance(ftype, OverloadedFunctionType):
                types_flatten.extend(ftype.types)
            else:
                # The check must look at the loop variable, not the builtin `type`.
                assert isinstance(ftype, FunctionType)
                types_flatten.append(ftype)
        self.types = tuple(types_flatten)

        super().__init__(self._gen_typename(), alias=alias)

    def _gen_typename(self) -> str:
        """Build the type name from the type names of all candidates."""
        return 'Overloaded{' + ','.join([x.typename for x in self.types]) + '}'

    @property
    def nr_types(self) -> int:
        """Return the number of sub-types."""
        return len(self.types)

    @classmethod
    def from_function_type_with_union_arguments(cls, function_type: FunctionType):
        """Create an OverloadedFunctionType from a FunctionType with Union-Typed arguments.

        One candidate is created for every combination of the union members; non-union arguments
        contribute a single choice. The alias of the result is the type name of ``function_type``.
        """
        product_bases = list()
        for arg_type in function_type.argument_types:
            if isinstance(arg_type, UnionType):
                product_bases.append(arg_type.types)
            else:
                product_bases.append([arg_type])

        product_types = tuple(
            FunctionType(
                arg_types, function_type.return_type,
                argument_names=function_type.argument_names
            ) for arg_types in itertools.product(*product_bases)
        )
        return cls(product_types, alias=function_type.typename)

    def resolve_type_and_args(self, *args: Any, **kwargs: Any) -> Union[OverloadedFunctionResolution, OverloadedFunctionAmbiguousResolutions]:
        """Resolve the exact sub-function type being called and the argument list.

        Args:
            *args: The positional arguments.
            **kwargs: The keyword arguments.

        Returns:
            A :class:`OverloadedFunctionResolution` object if the resolution is unambiguous, or a :class:`OverloadedFunctionAmbiguousResolutions` object if the resolution is ambiguous.
            The ambiguity resolution object will only be returned if the ``check_overloaded_ambiguity`` flag is set to False in the :class:`FunctionArgumentResolutionContext`.

            - The :class:`OverloadedFunctionResolution` object contains the index of the sub-function type being called, the sub-function type, and the resolved argument list.
            - The :class:`OverloadedFunctionAmbiguousResolutions` object contains a list of :class:`OverloadedFunctionResolution` objects.

        Raises:
            FunctionArgumentResolutionError: if no candidate matches, or if several match while
                ``check_overloaded_ambiguity`` is enabled.
        """
        resolution_context = get_function_argument_resolution_context()

        success_results = list()
        exceptions = list()
        for i, ftype in enumerate(self.types):
            try:
                arguments = ftype.resolve_args(*args, **kwargs)
                success_results.append(OverloadedFunctionResolution(i, ftype, arguments))
            except FunctionArgumentResolutionError as e:
                exceptions.append(e)

        if len(success_results) == 1:
            return success_results[0]
        elif len(success_results) == 0:
            with resolution_context.exc():
                # No candidate succeeded, so `exceptions` holds exactly one entry per candidate type.
                fmt = 'Failed to resolve overloaded function{}.\n'.format('' if self.typename is None else ' ' + self.typename)
                fmt += 'Detailed messages are:\n'
                for ftype, r in zip(self.types, exceptions):
                    this_fmt = 'Trying ' + str(ftype) + ':\n'
                    this_fmt += indent_text(str(r))
                    fmt += indent_text(this_fmt) + '\n'
                raise FunctionArgumentResolutionError(fmt.rstrip())
        else:
            if resolution_context.check_overloaded_ambiguity:
                with resolution_context.exc():
                    fmt = 'Got ambiguous application of overloaded function{}.\n'.format('' if self.typename is None else ' ' + self.typename)
                    fmt += 'Candidates are:\n'
                    for r in success_results:
                        fmt += indent_text(str(r.ftype)) + '\n'
                    fmt += 'Invoked with arguments: {}.'.format(str(success_results[0].arguments))
                    raise FunctionArgumentResolutionError(fmt)
            else:
                return OverloadedFunctionAmbiguousResolutions(success_results)
class _FunctionTypingSugarInner(object):
    """Second stage of the ``FunctionTyping[return_type](...)`` sugar: calling it builds the :class:`FunctionType`."""

    def __init__(self, return_type):
        self.return_type = return_type

    def __call__(self, *args, **kwargs):
        """Create a function type whose arguments are given either all positionally or all by keyword.

        Args:
            *args: the argument types (argument names are generated automatically).
            **kwargs: the argument types keyed by argument name.

        Returns:
            The :class:`FunctionType` with the stored return type.
        """
        if args:
            assert len(kwargs) == 0, 'Only support all positional arguments or all positional keyword arguments.'
            return FunctionType(args, self.return_type)
        if kwargs:
            return FunctionType(tuple(kwargs.values()), self.return_type, tuple(kwargs.keys()))
        return FunctionType(tuple(), self.return_type)


class _FunctionTypingSugar(object):
    """First stage of the sugar: subscripting with the return type yields the callable second stage."""

    def __getitem__(self, return_type):
        return _FunctionTypingSugarInner(return_type)


FunctionTyping = _FunctionTypingSugar()
"""FunctionTyping is a language-sugar constructor for function types. For example: `FunctionTyping[BOOL](INT64, INT64)` creates a function type with two INT64 arguments and a BOOL return type."""
class FunctionOverriddenCallList(list, List[Callable]):
    """A list of overridden ``__call__`` implementations for a single function.

    It is only needed during partial evaluation, when the concrete function type
    cannot be resolved yet and several implementations remain possible.
    """
class FunctionDerivedExpressionList(list, List['Expression']):
    """A list holding several derived expressions that belong to one function."""
@dataclass
class FunctionResolvedFromRecord(object):
    """Record of where a resolved (specialized or partially applied) function came from."""

    # The function this one was resolved from.
    function: Callable
    # Index of the selected sub-type of that function, or a tuple of indices when several remain possible.
    ftype_index: Union[int, Tuple[int, ...]]
class Function(object):
    """A function object holds a function type and an optional overridden __call__.

    The function object holds an additional field called `overridden_call`, which is a callable function.
    This field is used to override the __call__ method of the function object. By default, the __call__
    function returns a FunctionApplication object, which contains the name of the function and a list of
    arguments. However, when `overridden_call` is set, the __call__ method will return the result of
    calling `overridden_call` with the arguments.
    """

    def __init__(
        self,
        name: str,
        ftype: Union[FunctionType, OverloadedFunctionType],
        derived_expression: Optional[Union['Expression', FunctionDerivedExpressionList]] = None,
        overridden_call: Optional[Union[Callable, FunctionOverriddenCallList]] = None,
        resolved_from: Optional[FunctionResolvedFromRecord] = None,
        function_body: Optional[Union[Callable, Sequence[Callable]]] = None
    ):
        """
        Args:
            name: the name of the function.
            ftype: the function type.
            derived_expression: the expression that this function is derived from. When it is not given but
                `overridden_call` is, it is generated by invoking the overridden call on the declared arguments.
            overridden_call: the overridden call function.
            resolved_from: the record of the function that this function is resolved from. This is used for handling
                partial evaluation and function specialization (for overloaded functions).
            function_body: the function body.
        """
        self.ftype = ftype
        self.derived_expression = derived_expression
        self.overridden_call = overridden_call
        self.resolved_from = resolved_from

        if isinstance(self.ftype, OverloadedFunctionType) and isinstance(self.overridden_call, FunctionOverriddenCallList):
            assert self.ftype.nr_types == len(self.overridden_call)

        self.name = name
        self.function_body = function_body  # the function body defined during the declaration.

        if self.derived_expression is None and self.overridden_call is not None:
            if self.is_overloaded:
                self.derived_expression = FunctionDerivedExpressionList()
                for i in range(self.ftype.nr_types):
                    # get_overridden_call handles both a per-type call list and a single shared callable.
                    self.derived_expression.append(_gen_expression_from_overridden_call(self.ftype.types[i], self.get_overridden_call(i)))
            else:
                self.derived_expression = _gen_expression_from_overridden_call(ftype, self.overridden_call)

        self.is_derived = self.derived_expression is not None

    def set_function_name(self, function_name: str):
        """Set the function name."""
        self.name = function_name

    def set_function_body(self, function_body: Callable):
        """Set the function body."""
        self.function_body = function_body

    # Argument and return type of the function (when the function is not an overloaded one).

    @property
    def arguments(self) -> Tuple[Variable, ...]:
        """The argument variables. Only available for non-overloaded functions."""
        assert not self.is_overloaded
        return self.ftype.arguments

    @property
    def nr_arguments(self) -> int:
        """The number of arguments. Only available for non-overloaded functions."""
        assert not self.is_overloaded
        return self.ftype.nr_arguments

    @property
    def return_type(self) -> FunctionReturnType:
        """The return type. Only available for non-overloaded functions."""
        assert not self.is_overloaded
        return self.ftype.return_type

    @property
    def is_generator_function(self) -> bool:
        """Whether the function type is marked as a generator function."""
        return self.ftype.is_generator_function

    # When the function is overloaded, the following functions are used to get the "overridden calls"
    # for each function type.

    @property
    def is_overloaded(self) -> bool:
        """Return True if the function is overloaded."""
        return isinstance(self.ftype, OverloadedFunctionType)

    def get_overridden_call(self, ftype_index: Optional[int] = None) -> Optional[Callable]:
        """Get the overridden call function.

        Args:
            ftype_index: the index of the sub-type. Required when the overridden call is a
                :class:`FunctionOverriddenCallList`; ignored otherwise.
        """
        if isinstance(self.overridden_call, FunctionOverriddenCallList):
            assert ftype_index is not None
            return self.overridden_call[ftype_index]
        return self.overridden_call

    def get_sub_function(self, ftype_index: int) -> 'Function':
        """Create the function object for one sub-type of an overloaded function.

        Args:
            ftype_index: the index of the sub-type in ``self.ftype.types``.

        Returns:
            A new function (of the same class as ``self``) for the selected sub-type.
        """
        assert self.is_overloaded
        assert 0 <= ftype_index < self.ftype.nr_types

        # Reuse the already generated per-type expression, if there is one.
        derived_expression = None
        if isinstance(self.derived_expression, FunctionDerivedExpressionList):
            derived_expression = self.derived_expression[ftype_index]

        # The body may be a single callable shared by all sub-types (e.g., created by `from_function`)
        # or a sequence holding one body per sub-type.
        function_body = self.function_body
        if isinstance(function_body, (list, tuple)):
            function_body = function_body[ftype_index]

        # Everything after (name, ftype) is passed by keyword: the third positional parameter of
        # `__init__` is `derived_expression`, not `overridden_call`.
        return type(self)(
            self.name, self.ftype.types[ftype_index],
            derived_expression=derived_expression,
            overridden_call=self.get_overridden_call(ftype_index),
            resolved_from=FunctionResolvedFromRecord(self, ftype_index),
            function_body=function_body
        )

    @cached_property
    def all_sub_functions(self) -> List['Function']:
        """The sub-functions for all sub-types of an overloaded function."""
        assert self.is_overloaded
        return [self.get_sub_function(i) for i in range(self.ftype.nr_types)]

    @classmethod
    def from_function(cls, function: Callable, implementation: bool = True, sig: Optional[inspect.Signature] = None):
        """Create a function object from an actual Python function.

        Args:
            function: The function.
            implementation: Whether the function is an implementation. Defaults to True.
            sig: The signature of the function. Defaults to None.
        """
        ftype = FunctionType.from_annotation(function, sig=sig)
        return cls(function.__name__, ftype, function_body=function if implementation else None)

    def __call__(self, *args, **kwargs):
        """Apply the function.

        Returns:
            The result of the overridden call when there is one; otherwise a function application expression.

        Raises:
            FunctionArgumentResolutionError: if the arguments cannot be resolved.
        """
        if isinstance(self.ftype, OverloadedFunctionType):
            resolution = self.ftype.resolve_type_and_args(*args, **kwargs)
            if isinstance(resolution, OverloadedFunctionAmbiguousResolutions):
                # Only reachable when the ambiguity check is disabled in the current resolution context.
                raise FunctionArgumentResolutionError(
                    'Got ambiguous application of overloaded function{}.'.format('' if self.name is None else ' ' + self.name)
                )
            return self._apply_with_resolved_args(resolution.arguments, resolution.type_index, resolution.ftype)

        resolved_args = self.ftype.resolve_args(*args, **kwargs)
        return self._apply_with_resolved_args(resolved_args)

    def __str__(self):
        if self.is_derived and not self.is_overloaded:
            with FormatContext(type_format_cls=False).as_default():
                if get_format_context().function_format_lambda:
                    fmt = ''.join(['lam ' + str(n.name) + '.' for n in self.arguments])
                    with FormatContext(object_format_type=False).as_default():
                        fmt += str(self.derived_expression)
                else:
                    fmt = 'def ' + self.name + '(' + ', '.join([str(x) for x in self.arguments]) + '): '
                    with FormatContext(object_format_type=False).as_default():
                        fmt += 'return ' + indent_text(str(self.derived_expression)).lstrip()
        else:
            if isinstance(self.ftype, OverloadedFunctionType):
                fmt = '\n'.join([f'{func_type}' for func_type in self.ftype.types])
                if self.name is not None:
                    # NOTE(review): the parentheses in ' (overloaded): ' are not escaped, so they form a regex
                    # group rather than matching literal parentheses. Pattern kept unchanged -- confirm intent.
                    fmt = re.sub(r'^' + re.escape(self.name) + ' (overloaded): ', '', fmt, flags=re.MULTILINE)
                    fmt = self.name + ': ' + '\n' + indent_text(fmt)
            else:
                fmt = f'{self.name}{self.ftype}'
        return fmt

    __repr__ = jacinle.repr_from_str

    def remap_arguments(self, remapping: List[int]) -> 'Function':
        """Generate a new Function object with a different argument order. Specifically,
        remapping is a permutation. The i-th argument to the new function will be the
        remapping[i]-th argument in the old function.

        Args:
            remapping: The remapping.

        Returns:
            The new function.

        Raises:
            NotImplementedError: if the function is overloaded.
        """
        if isinstance(self.ftype, OverloadedFunctionType):
            raise NotImplementedError('Argument remapping for overloaded functions are not implemented.')

        new_argument_types = [self.ftype.argument_types[i] for i in remapping]

        def new_overridden_call(*args):
            # Put the i-th new argument back into position remapping[i] of the original function.
            remapped_args = [None for _ in range(len(args))]
            for i, arg in enumerate(args):
                remapped_args[remapping[i]] = arg
            return self(*remapped_args)

        return Function(
            self.name,
            FunctionType(new_argument_types, self.ftype.return_type),
            overridden_call=new_overridden_call,
            resolved_from=self.resolved_from
        )

    @staticmethod
    def _unbound_argument_indices(resolved_args) -> List[int]:
        """Return the positions in ``resolved_args`` that are still :data:`FunctionArgumentUnset`."""
        return [i for i, arg in enumerate(resolved_args) if arg is FunctionArgumentUnset]

    def partial(self, *args, execute_fully_bound_functions=False, **kwargs) -> Union['Function', 'FunctionApplicationExpression']:
        """Bind some of the arguments and return a function of the remaining ones.

        Args:
            *args: the positional arguments to bind.
            execute_fully_bound_functions: for non-overloaded functions, whether to apply the function
                immediately when all arguments are bound. Overloaded functions are always applied
                immediately once a resolution has all of its arguments bound.
            **kwargs: the keyword arguments to bind.

        Returns:
            A new function over the unbound arguments, or the result of applying the function
            when it is fully bound (see above).

        Raises:
            FunctionArgumentResolutionError: if the arguments cannot be resolved, or if several
                overloads are fully bound by the given arguments.
        """
        new_name = '__lambda__' if self.name == '__lambda__' else f'{self.name}_partial'
        new_overridden_call = None
        new_resolved_from = None

        if self.is_overloaded:
            with FunctionArgumentResolutionContext(
                check_missing=False, check_overloaded_ambiguity=False
            ).as_default():
                types_and_arguments = self.ftype.resolve_type_and_args(*args, **kwargs)

            if not isinstance(types_and_arguments, OverloadedFunctionAmbiguousResolutions):
                resolution = types_and_arguments
                unmapped_arguments = self._unbound_argument_indices(resolution.arguments)
                if len(unmapped_arguments) == 0:
                    return self._apply_with_resolved_args(resolution.arguments, resolution.type_index, resolution.ftype)
                new_type = _gen_partial_function_type(resolution.ftype, unmapped_arguments)
                new_resolved_from = FunctionResolvedFromRecord(self, resolution.type_index)
                # Without an overridden call the bound arguments would be dropped when the result is applied.
                new_overridden_call = _gen_partial_overriden_call(new_type, resolution.arguments, self)
            else:
                # If there is one specific resolution s.t. all variables are grounded, use it.
                all_grounded_resolutions = [
                    r for r in types_and_arguments
                    if len(self._unbound_argument_indices(r.arguments)) == 0
                ]
                if len(all_grounded_resolutions) == 1:
                    resolution = all_grounded_resolutions[0]
                    return self._apply_with_resolved_args(resolution.arguments, resolution.type_index, resolution.ftype)
                elif len(all_grounded_resolutions) > 1:
                    with get_function_argument_resolution_context().exc():
                        fmt = 'Got ambiguous application of overloaded function{}.\n'.format(
                            '' if self.name is None else ' ' + self.name)
                        fmt += 'Candidates are:\n'
                        for r in all_grounded_resolutions:
                            fmt += indent_text(str(r.ftype)) + '\n'
                        fmt += 'Invoked with arguments: {}.'.format(str(all_grounded_resolutions[0].arguments))
                        raise FunctionArgumentResolutionError(fmt)

                # Otherwise, keep every candidate and produce an overloaded partial function.
                possible_resolution_ids = list()
                possible_resolutions = list()
                possible_overridden_calls = list()
                for r in types_and_arguments:
                    new_subtype = _gen_partial_function_type(r.ftype, self._unbound_argument_indices(r.arguments))
                    possible_resolution_ids.append(r.type_index)
                    possible_resolutions.append(new_subtype)
                    possible_overridden_calls.append(_gen_partial_overriden_call(new_subtype, r.arguments, self))
                new_type = OverloadedFunctionType(possible_resolutions)
                new_resolved_from = FunctionResolvedFromRecord(self, tuple(possible_resolution_ids))
                new_overridden_call = FunctionOverriddenCallList(possible_overridden_calls)
        else:
            with FunctionArgumentResolutionContext(check_missing=False).as_default():
                resolved_args = self.ftype.resolve_args(*args, **kwargs)
            unmapped_arguments = self._unbound_argument_indices(resolved_args)
            if execute_fully_bound_functions and len(unmapped_arguments) == 0:
                return self._apply_with_resolved_args(resolved_args)
            new_type = _gen_partial_function_type(self.ftype, unmapped_arguments)
            new_overridden_call = _gen_partial_overriden_call(new_type, resolved_args, self)

        return Function(
            new_name, new_type,
            overridden_call=new_overridden_call,
            resolved_from=new_resolved_from
        )

    def _apply_with_resolved_args(
        self, resolved_args,
        resolved_ftype_id=None,
        resolved_function_type=None
    ):
        """Apply the function to an already resolved argument list.

        Args:
            resolved_args: the resolved argument list.
            resolved_ftype_id: the index of the resolved sub-type (overloaded functions only).
            resolved_function_type: the resolved sub-type (overloaded functions only).

        Returns:
            The result of the overridden call when there is one; otherwise a function application expression.
        """
        if self.overridden_call is not None:
            return self.get_overridden_call(resolved_ftype_id)(*resolved_args)

        if isinstance(self.ftype, OverloadedFunctionType):
            function = Function(
                self.name + f'_{resolved_ftype_id}',
                resolved_function_type,
                overridden_call=None,  # Must be none.
                resolved_from=FunctionResolvedFromRecord(self, resolved_ftype_id)
            )
        else:
            function = self

        from concepts.dsl.expression import FunctionApplicationExpression, cvt_expression_list
        return FunctionApplicationExpression(function, cvt_expression_list(resolved_args, function.ftype.argument_types))
def _gen_expression_from_overridden_call(ftype: FunctionType, overridden_call: Callable):
    """Invoke ``overridden_call`` on the declared arguments of ``ftype`` inside a fresh expression definition context.

    Args:
        ftype: the function type whose ``arguments`` are looked up in the context.
        overridden_call: the callable to invoke.

    Returns:
        Whatever ``overridden_call`` returns.
    """
    from concepts.dsl.expression import ExpressionDefinitionContext

    definition_context = ExpressionDefinitionContext()
    with definition_context.as_default():
        call_arguments = [definition_context[variable] for variable in ftype.arguments]
        return overridden_call(*call_arguments)


def _gen_partial_function_type(old_type: FunctionType, unmapped_arguments):
    """Create the function type that keeps only the arguments at the indices in ``unmapped_arguments``.

    The return type is taken over from ``old_type``; argument names are carried over when present.
    """
    kept_types = [old_type.argument_types[index] for index in unmapped_arguments]
    old_names = old_type.argument_names
    kept_names = None if old_names is None else [old_names[index] for index in unmapped_arguments]
    return FunctionType(kept_types, old_type.return_type, argument_names=kept_names)


def _gen_partial_overriden_call(new_type, resolved_args, call):
    """Create a callable that fills the unset slots of ``resolved_args`` and forwards to ``call``.

    Args:
        new_type: the function type of the remaining (unbound) arguments; used to resolve the new arguments.
        resolved_args: the partially bound argument list; unbound slots hold ``FunctionArgumentUnset``.
        call: the callable receiving the completed argument list.
    """
    assert isinstance(new_type, FunctionType)

    def partial_overriden_call(*new_args, **new_kwargs):
        fillers = new_type.resolve_args(*new_args, **new_kwargs)
        # Work on a copy so that the captured argument list can be reused across calls.
        completed = resolved_args.copy()
        open_slots = [slot for slot, value in enumerate(completed) if value is FunctionArgumentUnset]
        # Unset slots are filled from left to right with the newly resolved arguments.
        for position, slot in enumerate(open_slots):
            completed[slot] = fillers[position]
        return call(*completed)

    return partial_overriden_call