import torch
from concepts.dsl.dsl_types import BOOL, INT64, FLOAT32, ObjectType, VectorValueType, Variable
from concepts.dsl.dsl_functions import Function, FunctionType
from concepts.dsl.function_domain import FunctionDomain
from concepts.dsl.expression import VariableExpression, FunctionApplicationExpression
from concepts.dsl.tensor_value import TensorValue, from_tensor
from concepts.dsl.tensor_state import NamedObjectTensorState
from concepts.dsl.executors.tensor_value_executor import FunctionDomainTensorValueExecutor
# See the documentation for TensorState for more details.
# --- Domain: one object type (`person`) and one binary boolean relation. ---
domain = FunctionDomain()
domain.define_type(ObjectType('person'))

# `is_friend` has the signature `person x person -> bool`.
_is_friend_type = FunctionType(
    [ObjectType('person'), ObjectType('person')],
    BOOL,
)
domain.define_function(Function('is_friend', _is_friend_type))

# --- State: a 3x3 boolean table for `is_friend` over the batch variables x, y. ---
# The table is symmetric with a true diagonal; the only false entries are
# (1, 2) and (2, 1).
# NOTE(review): presumably row/column indices follow the declaration order of
# the object names below (Alice=0, Bob=1, Charlie=2) -- confirm against the
# NamedObjectTensorState documentation.
_friendship = torch.tensor(
    [
        [True, True, True],
        [True, True, False],
        [True, False, True],
    ],
    dtype=torch.bool,
)

# Every named object is a `person`; insertion order is Alice, Bob, Charlie.
_people = {name: ObjectType('person') for name in ('Alice', 'Bob', 'Charlie')}

state = NamedObjectTensorState(
    {'is_friend': TensorValue(BOOL, ['x', 'y'], _friendship)},
    object_names=_people,
)

# --- Executor bound to the domain defined above. ---
executor = FunctionDomainTensorValueExecutor(domain)