#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : function_domain_search.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 12/10/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""An enumerative search algorithm to generate candidate functions and expressions in a simple function domain."""
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Optional, Union, Iterable, Sequence, Tuple, List, Callable
from concepts.dsl.dsl_types import ConstantType, ValueType
from concepts.dsl.dsl_functions import FunctionType, Function
from concepts.dsl.dsl_domain import DSLDomainBase
from concepts.dsl.value import Value
from concepts.dsl.expression import ConstantExpression, FunctionApplicationExpression, VariableExpression
from concepts.dsl.function_domain import FunctionDomain
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor
# Names exported by ``from <this module> import *``.
__all__ = [
'FunctionDomainExpressionSearchResult', 'FunctionDomainExpressionEnumerativeSearcher', 'gen_merge_functions',
'gen_expression_search_result_from_expressions',
'FunctionArgumentStat', 'stat_function', 'canonize_function_parameters',
'learn_expression_from_examples'
]
# Shorthand for ``FunctionDomain.AllowedTypes``; used in the ``return_type`` annotations of this module.
_Types = FunctionDomain.AllowedTypes
[docs]
@dataclass
class FunctionDomainExpressionSearchResult:
    """One enumerated expression together with its depth and argument counts."""

    #: The expression that is enumerated.
    expression: Union[ConstantExpression, Function, FunctionApplicationExpression]

    #: The depth of the expression.
    depth: int

    #: The number of constant arguments in the expression.
    nr_constant_arguments: int

    #: The number of variable arguments in the expression.
    nr_variable_arguments: int

    #: The number of function arguments in the expression.
    nr_function_arguments: int
[docs]
class FunctionDomainExpressionEnumerativeSearcher(object):
    """An enumerator of expressions and functions for a function domain.

    The searcher indexes the constants of the domain by their ``dtype`` and the functions of the
    domain by their return type, and then composes functions bottom-up to enumerate candidates.
    """

    def __init__(self, domain: DSLDomainBase):
        """Initialize the searcher.

        Args:
            domain: the domain of the semantics.
        """
        self.domain = domain
        self._constant_type_cache = self._gen_constant_type_cache()
        self._function_type_cache = self._gen_function_type_cache()

    def _gen_constant_type_cache(self):
        """Generate a dictionary that maps types to constants."""
        cache = defaultdict(list)
        for const in self.domain.constants.values():
            cache[const.dtype].append(const)
        return cache

    def _gen_function_type_cache(self):
        """Generate a dictionary that maps types to functions (by return types).

        Overloaded functions are expanded into their sub-functions, each indexed separately.
        """
        cache = defaultdict(list)
        for func in self.domain.functions.values():
            if func.is_overloaded:
                for sub_function in func.all_sub_functions:
                    cache[sub_function.return_type].append(sub_function)
            else:
                cache[func.ftype.return_type].append(func)
        return cache

    def gen(
        self,
        return_type: Optional[Union[_Types, Tuple[_Types, ...], List[_Types]]] = None,
        *,
        max_depth: int = 3,
        max_variable_arguments: int = 2,
        max_constant_arguments: int = 1,
        max_function_arguments: int = 0,
        search_constants: bool = False,
        hash_function: Optional[Callable[[Union[Function, FunctionApplicationExpression]], Any]] = None,
        verbose: bool = False,
    ) -> List[FunctionDomainExpressionSearchResult]:
        """Generate constant expressions followed by functions and function application expressions.

        This is the concatenation of :meth:`gen_constant_expressions` and
        :meth:`gen_function_application_expressions`; all keyword arguments are forwarded to the latter.

        Args:
            return_type: the return type of the expressions. If None, all types are allowed.
                It can be a single type, a tuple of types, or a list of types.
            max_depth: the maximum depth of the expressions.
            max_variable_arguments: the maximum number of variable arguments of the functions.
            max_constant_arguments: the maximum number of constant arguments of the functions.
            max_function_arguments: the maximum number of function arguments of the functions.
            search_constants: whether to search for constants.
            hash_function: a custom function used to detect duplicated expressions.
                If None, ``str`` is used.
            verbose: whether to print the search progress.

        Returns:
            A list of search results: constant expressions first, then function (application) expressions.
        """
        constant_results = self.gen_constant_expressions(return_type)
        function_results = self.gen_function_application_expressions(
            return_type, max_depth=max_depth, max_variable_arguments=max_variable_arguments,
            max_constant_arguments=max_constant_arguments, max_function_arguments=max_function_arguments,
            search_constants=search_constants, hash_function=hash_function, verbose=verbose
        )
        return constant_results + function_results

    def gen_constant_expressions(
        self,
        return_type: Optional[Union[_Types, Tuple[_Types, ...], List[_Types]]] = None
    ) -> List[FunctionDomainExpressionSearchResult]:
        """Generate constant expressions of a set of given types.

        Args:
            return_type: the return type of the expressions. If None, all types are allowed.
                It can be a single type, a tuple of types, or a list of types.

        Returns:
            A list of constant expressions. Each result has depth 1 and zero argument counts.
        """
        return [
            FunctionDomainExpressionSearchResult(ConstantExpression(c), 1, 0, 0, 0)
            for c in self.domain.constants.values()
            if (
                return_type is None or
                (isinstance(return_type, ConstantType) and c.dtype == return_type) or
                (isinstance(return_type, (tuple, list)) and c.dtype in return_type)
            )
        ]

    def gen_function_application_expressions(
        self,
        return_type: Optional[Union[_Types, Tuple[_Types, ...], List[_Types]]] = None,
        *,
        max_depth: int = 3,
        max_variable_arguments: int = 2,
        max_constant_arguments: int = 1,
        max_function_arguments: int = 0,
        search_constants: bool = False,
        hash_function: Optional[Callable[[Union[Function, FunctionApplicationExpression]], Any]] = None,
        verbose: bool = False,
    ) -> List[FunctionDomainExpressionSearchResult]:
        """Generate functions and function application expressions of a set of given types.

        Args:
            return_type: the return type of the expressions. If None, all types are allowed.
                It can be a single type, a tuple of types, or a list of types.
            max_depth: the maximum depth of the expressions. If it is smaller than 1, nothing can be
                enumerated and an empty list is returned.
            max_variable_arguments: the maximum number of variable arguments of the functions.
            max_constant_arguments: the maximum number of constant arguments of the functions.
                Note that when ``search_constants`` is True, this parameter corresponds to the maximum number of
                constant arguments bound to the function / function application expression.
            max_function_arguments: the maximum number of function arguments of the functions. When it is
                positive, "call a function argument" primitives are added to the depth-1 candidates.
            search_constants: whether to search for constants.
            hash_function: a custom function used to detect duplicated expressions.
                If None, ``str`` is used.
            verbose: whether to print the search progress.

        Returns:
            A list of function application expressions.
        """
        # No expression has depth < 1. Without this guard, ``current[1]`` below raises a KeyError.
        if max_depth < 1:
            return []

        # current[depth][return_type] -> list of candidate functions enumerated at that depth.
        current = {level: defaultdict(list) for level in range(1, max_depth + 1)}

        def iter_ddl_values(dd):
            for v in dd.values():
                yield from v

        def print_level(level):
            print('-' * 20, 'Depth', level, '-' * 100)
            for rtype, functions in current[level].items():
                for f in functions:
                    print(rtype, '\t', f)

        # Depth 1: a wrapper around each domain function ...
        for functions in self._function_type_cache.values():
            for f in functions:
                current[1][f.ftype.return_type].append(gen_merge_functions(f))
        # ... plus, optionally, primitives that call a function-typed argument.
        if max_function_arguments > 0:
            for ret_type in self.domain.types.values():
                current[1][ret_type].extend(self._gen_function_primitives(ret_type, max_function_arguments))
        if verbose:
            print_level(1)

        # Depth >= 2: plug a shallower function f2 into one argument slot of a shallower function f1.
        for depth in range(2, max_depth + 1):
            for depth1 in range(1, depth):
                # Snapshot the candidates of ``depth1`` before iterating; lookups below may add keys.
                for f1 in list(iter_ddl_values(current[depth1])):
                    for i in range(f1.nr_arguments):
                        for depth2 in range(1, depth - depth1 + 1):
                            for f2 in current[depth2][f1.ftype.argument_types[i]]:
                                current[depth][f1.ftype.return_type].append(
                                    gen_merge_functions(f1, i, f2)
                                )
            if verbose:
                print_level(depth)

        expressions = list()
        for depth, candidates_by_type in current.items():
            for f in iter_ddl_values(candidates_by_type):
                if isinstance(f, Function):
                    stat = stat_function(f)
                    # Functions exceeding any of the argument budgets are dropped.
                    if stat.nr_constant_arguments <= max_constant_arguments and \
                            stat.nr_variable_arguments <= max_variable_arguments and \
                            stat.nr_function_arguments <= max_function_arguments:
                        expressions.append(FunctionDomainExpressionSearchResult(
                            canonize_function_parameters(f, ignore_permutation=True),
                            depth, stat.nr_constant_arguments, stat.nr_variable_arguments, stat.nr_function_arguments
                        ))
                else:
                    expressions.append(FunctionDomainExpressionSearchResult(
                        f, depth, 0, 0, 0
                    ))

        # Shallower results come first, so deduplication keeps the shallowest copy of each expression.
        expressions = self._unique_function_expressions(expressions, hash_function=hash_function)
        if search_constants:
            expressions = self._bind_constants_to_expressions(expressions)
        if return_type is not None:
            expressions = [
                result for result in expressions
                if _match_return_type(result.expression, return_type)
            ]
        return expressions

    def _gen_function_primitives(self, ret_type, nr_function_arguments):
        """Generate primitives that apply a function-typed first argument to the remaining arguments.

        For every combination of 1 to ``nr_function_arguments`` domain types ``arg_types``, this yields a
        function whose signature is ``(FunctionType(arg_types, ret_type), *arg_types) -> ret_type``.

        Args:
            ret_type: the return type of the generated primitives.
            nr_function_arguments: the maximum number of arguments taken by the function-typed argument.

        Returns:
            A tuple of the generated functions.
        """
        def function_call(func, *args):
            return func(*args)

        def gen():
            types = tuple(self.domain.types.values())
            for repeat in range(1, nr_function_arguments + 1):
                for arg_types in itertools.product(types, repeat=repeat):
                    yield Function(
                        '__lambda__',
                        FunctionType(
                            [FunctionType(arg_types, ret_type), ] + list(arg_types),
                            ret_type
                        ),
                        overridden_call=function_call,
                    )

        return tuple(gen())

    def _unique_function_expressions(
        self,
        functions: List[FunctionDomainExpressionSearchResult],
        hash_function: Optional[Callable[[Union[Function, FunctionApplicationExpression]], Any]] = None
    ) -> List[FunctionDomainExpressionSearchResult]:
        """Return a list of unique functions and function application expressions.

        Args:
            functions: a list of input expressions.
            hash_function: a custom function that generates a hash for a function or function application expression.
                If None, the default ``str`` function is used.

        Returns:
            A list of expressions without duplicates. The first occurrence of each hash is kept and
            the input order is preserved.
        """
        if hash_function is None:
            hash_function = str

        unique_functions = list()
        seen_hashes = set()
        for result in functions:
            h = hash_function(result.expression)
            if h not in seen_hashes:
                seen_hashes.add(h)
                unique_functions.append(result)
        return unique_functions

    def _bind_constants_to_expressions(
        self,
        functions: List[FunctionDomainExpressionSearchResult],
    ) -> List[FunctionDomainExpressionSearchResult]:
        """Bind every constant-typed argument of each function to every constant of the matching type.

        Results whose expression is not a :class:`Function` are passed through unchanged. For a function,
        one result is produced per combination of constants; depth and argument counts are copied from
        the unbound result.

        Args:
            functions: a list of input search results.

        Returns:
            A list of search results with constant arguments bound.
        """
        output_functions = list()
        for result in functions:
            f = result.expression
            if not isinstance(f, Function):
                output_functions.append(result)
                continue

            # Positional arguments are addressed as '#<index>' when calling ``partial``.
            constant_denotations, constant_types = list(), list()
            for index, argument_type in enumerate(f.ftype.argument_types):
                if isinstance(argument_type, ConstantType):
                    constant_denotations.append(f'#{index}')
                    constant_types.append(argument_type)

            constant_pools = [self._constant_type_cache[t] for t in constant_types]
            for const in itertools.product(*constant_pools):
                const_mapping = {k: ConstantExpression(v) for k, v in zip(constant_denotations, const)}
                output_functions.append(FunctionDomainExpressionSearchResult(
                    f.partial(**const_mapping, execute_fully_bound_functions=True),
                    result.depth,
                    result.nr_constant_arguments, result.nr_variable_arguments, result.nr_function_arguments
                ))
        return output_functions
def _match_return_type(f: Union[Function, FunctionApplicationExpression], return_types: Union[_Types, Tuple[_Types, ...], List[_Types]]) -> bool:
    """Check if the type of a function or a function application expression matches a set of given return type.

    Args:
        f: the function or function application expression to be checked.
        return_types: a single type, or a tuple / list of types.

    Returns:
        For a :class:`Function`, True iff the ``typename`` of its function type equals the ``typename`` of
        one of the given types. For a :class:`FunctionApplicationExpression`, True iff its ``return_type``
        is one of the given types.

    Raises:
        TypeError: if ``f`` is neither a function nor a function application expression.
    """
    if not isinstance(return_types, (tuple, list)):
        return_types = (return_types, )

    if isinstance(f, Function):
        # ``any`` yields an explicit False when nothing matches (instead of falling through to None).
        return any(f.ftype.typename == return_type.typename for return_type in return_types)
    if isinstance(f, FunctionApplicationExpression):
        return f.return_type in return_types
    raise TypeError(f'Expected Function or FunctionApplicationExpression, got {type(f)}.')
[docs]
def gen_merge_functions(f1: Function, arg_index=None, f2: Optional[Function] = None) -> Function:
    """Compose two functions by feeding the output of ``f2`` into one argument slot of ``f1``.

    The arguments of the composed function are the arguments of ``f2``, followed by the arguments
    of ``f1`` with the slot ``arg_index`` removed. For example, given f1(x, y) and f2(z),
    with arg_index = 1, this function generates:

    .. code-block:: python

        def merged(z, x):
            return f1(x, f2(z))

    When ``arg_index`` is None, ``f2`` is ignored and the result is simply a wrapper of ``f1``
    with the same argument types and return type.

    Args:
        f1: the outer function.
        arg_index: the index of the argument of f1 that receives the output of f2.
        f2: the inner function.

    Returns:
        A new function named ``'__lambda__'``.
    """
    if arg_index is None:
        wrapper_type = FunctionType(f1.ftype.argument_types, f1.ftype.return_type)
        return Function('__lambda__', wrapper_type, overridden_call=f1)

    outer_types = f1.ftype.argument_types
    inner_types = f2.ftype.argument_types
    # Inner arguments first, then the outer arguments minus the slot that is being filled.
    merged_types = inner_types + outer_types[:arg_index] + outer_types[arg_index + 1:]

    def new_function_call(*args):
        inner_args = args[:f2.nr_arguments]
        outer_args = list(args[f2.nr_arguments:])
        outer_args.insert(arg_index, f2(*inner_args))
        return f1(*outer_args)

    merged_type = FunctionType(merged_types, f1.ftype.return_type)
    return Function('__lambda__', merged_type, overridden_call=new_function_call)
[docs]
def gen_expression_search_result_from_expressions(expressions: Iterable[Union[ConstantExpression, Function, FunctionApplicationExpression]]) -> List[FunctionDomainExpressionSearchResult]:
    """Wrap each expression into a FunctionDomainExpressionSearchResult.

    All results get depth 0. Constant expressions and function application expressions get zero
    argument counts; for anything else, the counts are computed by :func:`stat_function`.
    """
    def wrap(expression):
        if isinstance(expression, (ConstantExpression, FunctionApplicationExpression)):
            return FunctionDomainExpressionSearchResult(expression, 0, 0, 0, 0)
        stat = stat_function(expression)
        return FunctionDomainExpressionSearchResult(
            expression, 0,
            stat.nr_constant_arguments, stat.nr_variable_arguments, stat.nr_function_arguments
        )

    return [wrap(expression) for expression in expressions]
[docs]
@dataclass
class FunctionArgumentStat:
    """Argument counts of a function, as computed by ``stat_function``."""

    #: Number of arguments whose type is a ConstantType.
    nr_constant_arguments: int

    #: Number of arguments that are not constants (function-typed arguments are included).
    nr_variable_arguments: int

    #: Number of arguments whose type is a FunctionType.
    nr_function_arguments: int
[docs]
def stat_function(f: Function) -> FunctionArgumentStat:
    """Count the constant, variable, and function arguments of a function.

    Every argument whose type is not a ConstantType counts as a variable argument; among those,
    arguments whose type is a FunctionType are additionally counted as function arguments.
    """
    constants, variables, functions = 0, 0, 0
    for argument_type in f.ftype.argument_types:
        if isinstance(argument_type, ConstantType):
            constants += 1
            continue
        # Anything that is not a constant is a variable argument, function arguments included.
        variables += 1
        if isinstance(argument_type, FunctionType):
            functions += 1

    return FunctionArgumentStat(
        nr_constant_arguments=constants,
        nr_variable_arguments=variables,
        nr_function_arguments=functions
    )
[docs]
def canonize_function_parameters(
    f: Union[ConstantExpression, FunctionApplicationExpression, Function],
    ignore_permutation: bool = False
) -> Union[ConstantExpression, FunctionApplicationExpression, Function]:
    """Reorder the arguments of a function into the canonical order: functions, variables, constants.

    Args:
        f: the function to be canonized. Constant expressions, function application expressions,
            and functions without an ``overridden_call`` are returned as-is.
        ignore_permutation: if False, arguments keep their declaration order within each group.
            If True, arguments are ordered within each group by their first appearance in a
            depth-first walk of the derived expression of the function.

    Returns:
        A new function object with argument reordered: functions, variables, constants.

    Raises:
        TypeError: if an argument is of a type that cannot be assigned to any group.
    """
    if isinstance(f, (ConstantExpression, FunctionApplicationExpression)):
        return f

    assert isinstance(f, Function)
    # NB(Jiayuan Mao @ 2022/12/10): if the function is a primitive function, just return it.
    if f.overridden_call is None:
        return f

    assert not f.is_overloaded
    assert isinstance(f.derived_expression, FunctionApplicationExpression)

    function_args, variable_args, constant_args = [], [], []

    def record_once(group, index):
        # Keep only the first appearance of each argument index.
        if index not in group:
            group.append(index)

    def walk(node: FunctionApplicationExpression):
        # The callee itself may be one of the arguments of ``f``.
        if isinstance(node.function, VariableExpression):
            record_once(function_args, f.ftype.argument_names.index(node.function.name))

        for arg in node.arguments:
            if isinstance(arg, FunctionApplicationExpression):
                walk(arg)
            elif not isinstance(arg, VariableExpression):
                raise TypeError('Unknown type for anonymous argument type {}.'.format(type(arg)))
            else:
                new_index = f.ftype.argument_names.index(arg.name)
                if isinstance(arg.dtype, ConstantType):
                    record_once(constant_args, new_index)
                elif isinstance(arg.dtype, ValueType):
                    record_once(variable_args, new_index)
                else:
                    raise TypeError('Unknown type for anonymous argument #{}, type = {}.'.format(arg.name, arg.dtype))

    if not ignore_permutation:
        for arg_index, arg_t in enumerate(f.ftype.argument_types):
            if isinstance(arg_t, FunctionType):
                function_args.append(arg_index)
            elif isinstance(arg_t, ConstantType):
                constant_args.append(arg_index)
            elif isinstance(arg_t, ValueType):
                variable_args.append(arg_index)
            else:
                raise TypeError('Unknown type for argument {}, type = {}.'.format(arg_index, arg_t))
    else:
        walk(f.derived_expression)

    return f.remap_arguments(function_args + variable_args + constant_args)
[docs]
def learn_expression_from_examples(
    domain: FunctionDomain,
    executor: FunctionDomainExecutor,
    input_output: Sequence[Tuple[Sequence[Value], Value, Any]],
    criterion: Callable[[Value, Value], bool],
    candidate_expressions: Optional[Iterable[FunctionDomainExpressionSearchResult]] = None,
    max_depth: int = 3,
    max_function_arguments: int = 0,
    search_constants: bool = True,
) -> Union[ConstantExpression, Function, FunctionApplicationExpression]:
    """Learn a function from examples.

    Each candidate is scored by the number of examples on which ``criterion`` accepts its output,
    and the highest-scoring candidate is returned. The target type is inferred from the first
    example: if it has no inputs, the target is the dtype of its output; otherwise the target
    is a function type from the input dtypes to the output dtype.

    Args:
        domain: the function domain.
        executor: the executor of the function domain.
        input_output: a sequence of (input, output, grounding) triples.
        criterion: a function that takes two values (the execution result and the expected output)
            and returns True if they are equal.
        candidate_expressions: an iterable of candidate functions. If None, we will generate
            candidate functions automatically.
        max_depth: the maximum depth of the function. Only used when ``candidate_expressions`` is None.
        max_function_arguments: the maximum number of arguments of a function. Only used when
            ``candidate_expressions`` is None.
        search_constants: whether to search for constants. Only used when ``candidate_expressions`` is None.

    Returns:
        The expression of the best-scoring candidate. Note that the best candidate is returned even if
        it does not match all examples.

    Raises:
        ValueError: if there is no candidate expression to choose from.
    """
    assert len(input_output) > 0, 'No input-output pairs are given.'

    sample_input, sample_output, _ = input_output[0]
    if len(sample_input) == 0:
        target_type = sample_output.dtype
    else:
        target_type = FunctionType([v.dtype for v in sample_input], sample_output.dtype)

    if candidate_expressions is None:
        # Use a separate name for the searcher so the ``domain`` argument is not rebound.
        searcher = FunctionDomainExpressionEnumerativeSearcher(domain)
        candidate_expressions = searcher.gen_function_application_expressions(
            target_type,
            max_depth=max_depth,
            max_function_arguments=max_function_arguments,
            search_constants=search_constants
        )

    expects_function = isinstance(target_type, FunctionType)

    def score_function(result: FunctionDomainExpressionSearchResult) -> float:
        expression = result.expression
        if expects_function:
            # A function is wanted: function application expressions can never match.
            if isinstance(expression, FunctionApplicationExpression):
                return -1
            return sum(
                criterion(executor.execute_function(expression, *args, grounding=grounding), output)
                for args, output, grounding in input_output
            )
        # A value is wanted: functions can never match.
        if isinstance(expression, Function):
            return -1
        return sum(
            criterion(executor.execute(expression, grounding), output)
            for _, output, grounding in input_output
        )

    best = max(candidate_expressions, key=score_function, default=None)
    if best is None:
        raise ValueError(
            'No candidate expressions to choose from; '
            'provide candidate_expressions or relax the search limits.'
        )
    return best.expression