Tutorial 1.4: Use Enumerative Search to Learn a Function
[1]:
# From tutorial/1-dsl/1-types-and-functions
from concepts.dsl.dsl_types import ValueType, ConstantType, BOOL, FLOAT32, INT64, VectorValueType, FormatContext
from concepts.dsl.dsl_functions import Function, FunctionTyping
from concepts.dsl.function_domain import FunctionDomain
# --- Types -------------------------------------------------------------------
# `concept_name` is a constant type: its values are the string constants
# registered at the end of this cell. Colour names and shape names share this
# single type, which is why the enumerator later proposes expressions such as
# filter_color(scene(), 'box').
t_item = ValueType('item')
t_item_set = ValueType('item_set')
t_concept_name = ConstantType('concept_name')
t_shape = ValueType('shape')
t_color = ValueType('color')
t_size = VectorValueType(FLOAT32, 3, alias='size')  # length-3 float32 vector
t_int = INT64

domain = FunctionDomain()

# Register the types (same order as the original tutorial listing).
for _dsl_type in (t_item, t_item_set, t_concept_name, t_color, t_shape, t_size):
    domain.define_type(_dsl_type)

# (name, signature) pairs; FunctionTyping[ret](*arg_types) declares a function
# taking `arg_types` and returning `ret`.
_function_signatures = [
    ('scene', FunctionTyping[t_item_set]()),
    ('filter_color', FunctionTyping[t_item_set](t_item_set, t_concept_name)),
    ('filter_shape', FunctionTyping[t_item_set](t_item_set, t_concept_name)),
    ('unique', FunctionTyping[t_item](t_item_set)),
    ('color_of', FunctionTyping[t_color](t_item)),
    ('shape_of', FunctionTyping[t_shape](t_item)),
    ('size_of', FunctionTyping[t_size](t_item)),
    ('same_color', FunctionTyping[BOOL](t_color, t_color)),
    ('same_shape', FunctionTyping[BOOL](t_shape, t_shape)),
    ('same_size', FunctionTyping[BOOL](t_size, t_size)),
    ('count', FunctionTyping[t_int](t_item_set)),
]
for _fn_name, _fn_typing in _function_signatures:
    domain.define_function(Function(_fn_name, _fn_typing))

# Constants of type `concept_name` that searches may plug into filter_* calls.
for _const_name in ('box', 'sphere', 'red', 'blue', 'green'):
    domain.define_const(t_concept_name, _const_name)
[2]:
# From tutorial/1-dsl/2-execution
from dataclasses import dataclass, field
from typing import Tuple, List
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor
@dataclass()
class Item:
    """A single object in a scene: its colour name, shape name and 3-D size."""

    color: str  # compared against colour constants such as 'red'
    shape: str  # compared against shape constants such as 'box'
    size: Tuple[float, float, float]  # compared component-wise by Executor.same_size


@dataclass()
class Scene:
    """Grounding object for the executor: holds the items that ``scene()`` returns."""

    items: List[Item]
class Executor(FunctionDomainExecutor):
    """Python implementations of the domain's functions, evaluated on a ``Scene``.

    Each method is named after a function declared on ``domain``. Constructing
    the executor logs that each one was "automatically registered". The scene
    being evaluated is read from ``self.grounding``, presumably the object passed
    as ``grounding=`` to ``execute(...)``.
    """

    # Per-component tolerance used by ``same_size`` (previously a hard-coded 0.1).
    # Override on a subclass or an instance to make size matching stricter/looser.
    size_tolerance = 0.1

    def scene(self):
        """Return all items of the current grounding scene."""
        return self.grounding.items

    def filter_color(self, inputs, color_name):
        """Return the items of ``inputs`` whose ``color`` equals ``color_name``."""
        return [o for o in inputs if o.color == color_name]

    def filter_shape(self, inputs, shape_name):
        """Return the items of ``inputs`` whose ``shape`` equals ``shape_name``."""
        return [o for o in inputs if o.shape == shape_name]

    def unique(self, inputs):
        """Return the only item of ``inputs``.

        Raises:
            AssertionError: if ``inputs`` does not contain exactly one item.
        """
        # Explicit check instead of a bare ``assert``: assert statements are
        # stripped under ``python -O``. Without the check, an empty input would
        # raise IndexError and a multi-item input would silently yield its first
        # element. AssertionError is kept so callers see the same exception type.
        if len(inputs) != 1:
            raise AssertionError(f'unique() expects exactly one item, got {len(inputs)}.')
        return inputs[0]

    def color_of(self, obj):
        """Return the colour name of ``obj``."""
        return obj.color

    def shape_of(self, obj):
        """Return the shape name of ``obj``."""
        return obj.shape

    def size_of(self, obj):
        """Return the size tuple of ``obj``."""
        return obj.size

    def same_color(self, c1, c2):
        """Return whether two colour names are equal."""
        return c1 == c2

    def same_shape(self, s1, s2):
        """Return whether two shape names are equal."""
        return s1 == s2

    def same_size(self, z1, z2):
        """Return whether every component pair differs by less than ``size_tolerance``."""
        return all(abs(sz1 - sz2) < self.size_tolerance for sz1, sz2 in zip(z1, z2))

    def count(self, inputs):
        """Return the number of items in ``inputs``."""
        return len(inputs)


executor = Executor(domain)
15 16:55:46 Function scene automatically registered.
15 16:55:46 Function filter_color automatically registered.
15 16:55:46 Function filter_shape automatically registered.
15 16:55:46 Function unique automatically registered.
15 16:55:46 Function color_of automatically registered.
15 16:55:46 Function shape_of automatically registered.
15 16:55:46 Function size_of automatically registered.
15 16:55:46 Function same_color automatically registered.
15 16:55:46 Function same_shape automatically registered.
15 16:55:46 Function same_size automatically registered.
15 16:55:46 Function count automatically registered.
[3]:
# scene1: three boxes of different colours; exactly one of them is red.
scene1 = Scene(items=[
    Item(color='red', shape='box', size=(1, 1, 1)),
    Item(color='blue', shape='box', size=(1, 1, 1)),
    Item(color='green', shape='box', size=(2, 2, 2)),
])

# scene2: two identical red boxes (built as two separate Item instances).
scene2 = Scene(items=[
    Item(color='red', shape='box', size=(1, 1, 1)),
    Item(color='red', shape='box', size=(1, 1, 1)),
])
[4]:
# Ground-truth program: count the red items in the scene.
_red_items = domain.f_filter_color(domain.f_scene(), 'red')
target_expr = domain.f_count(_red_items)
print(target_expr)

# Evaluate the same program on both example scenes.
for _label, _grounding in (('scene1', scene1), ('scene2', scene2)):
    print(f'{_label}:', executor.execute(target_expr, grounding=_grounding))
count(filter_color(scene(), V(red, dtype=concept_name)))
scene1: V(1, dtype=int64)
scene2: V(2, dtype=int64)
[5]:
from concepts.dsl.learning.function_domain_search import FunctionDomainExpressionEnumerativeSearcher
[6]:
# Enumerate all well-typed expressions returning an int, up to depth 3,
# allowing domain constants to fill constant-typed argument slots.
enumerator = FunctionDomainExpressionEnumerativeSearcher(domain)
candidate_expressions = enumerator.gen_function_application_expressions(
    return_type=t_int, max_depth=3, search_constants=True,
)

# Print each candidate using the lambda-style function formatting.
with FormatContext(function_format_lambda=True).as_default():
    for candidate in candidate_expressions:
        print(candidate.expression)

num_candidates = len(candidate_expressions)
print(f'In total: {num_candidates} candidate expressions.')
count(scene())
count(filter_color(scene(), V(box, dtype=concept_name)))
count(filter_color(scene(), V(sphere, dtype=concept_name)))
count(filter_color(scene(), V(red, dtype=concept_name)))
count(filter_color(scene(), V(blue, dtype=concept_name)))
count(filter_color(scene(), V(green, dtype=concept_name)))
count(filter_shape(scene(), V(box, dtype=concept_name)))
count(filter_shape(scene(), V(sphere, dtype=concept_name)))
count(filter_shape(scene(), V(red, dtype=concept_name)))
count(filter_shape(scene(), V(blue, dtype=concept_name)))
count(filter_shape(scene(), V(green, dtype=concept_name)))
In total: 11 candidate expressions.
[7]:
from concepts.dsl.learning.function_domain_search import learn_expression_from_examples
[8]:
# One example per scene, each shaped ([], output of target_expr on the scene, scene).
# The empty list appears to be the (absent) explicit inputs — confirm against the
# learn_expression_from_examples signature.
io_examples = [
    ([], executor.execute(target_expr, grounding=grounding), grounding)
    for grounding in (scene1, scene2)
]
[9]:
def _values_match(lhs, rhs):
    """Compare two executor results by their ``.value`` payloads."""
    return lhs.value == rhs.value


# Search over the pre-enumerated candidates for one consistent with all examples.
learn_expression_from_examples(
    domain,
    executor,
    input_output=io_examples,
    criterion=_values_match,
    candidate_expressions=candidate_expressions,
)
[9]:
FunctionApplicationExpression<count(filter_color(scene(), V(red, dtype=concept_name)))>
[10]:
# Same search, but let the library infer the target type and enumerate the
# candidates itself (candidate_expressions=None).
learn_expression_from_examples(
    domain,
    executor,
    input_output=io_examples,
    criterion=lambda got, want: got.value == want.value,
    candidate_expressions=None,
)
[10]:
FunctionApplicationExpression<count(filter_color(scene(), V(red, dtype=concept_name)))>