#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : crow_planner.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 11/09/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
from typing import Any, Optional, Union, Iterable, Iterator, Tuple, List, Dict
from dataclasses import dataclass
import jacinle
import torch
from concepts.dsl.dsl_types import UnnamedPlaceholder, Variable
from concepts.dsl.expression import VariableExpression, ConstantExpression, ValueOutputExpression, ListExpansionExpression, is_and_expr, iter_exprs
from concepts.dsl.expression_visitor import IdentityExpressionVisitor
from concepts.dsl.constraint import OptimisticValue, GroupConstraint, ConstraintSatisfactionProblem, AssignmentDict
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorValue
from concepts.pdsketch.operator import OperatorApplicationExpression, OperatorApplier
from concepts.pdsketch.executor import PDSketchExecutor, PDSketchSGC
from concepts.pdsketch.domain import State
from concepts.pdsketch.regression_rule import CSPCommitFlag, AchieveExpression, FindExpression, SubgoalSerializability
from concepts.pdsketch.planners.optimistic_search_bilevel_utils import OptimisticSearchSymbolicPlan
from concepts.pdsketch.planners.optimistic_search import ground_actions
from concepts.pdsketch.csp_solvers.dpll_sampling import csp_dpll_sampling_solve
from concepts.pdsketch.crow.regression_utils import surface_fol_downcast, ground_fol_expression
from concepts.pdsketch.crow.regression_utils import evaluate_bool_scalar_expression, ground_operator_application_expression, gen_applicable_regression_rules, ApplicableRegressionRuleGroup
from concepts.pdsketch.crow.crow_state import PartiallyOrderedPlan, TotallyOrderedPlan
@dataclass
class SearchState(object):
    """Placeholder dataclass for a search state; it currently declares no fields."""
def crow_recursive(
    executor: PDSketchExecutor, state: State, goal_expr: Union[str, ValueOutputExpression], *,
    is_goal_serialized: bool = True,
    enable_reordering: bool = False,
    enable_csp: bool = True,
    max_actions: int = 10,
    max_csp_branching_factor: int = 5, max_beam_size: int = 20,
    allow_empty_plan_for_optimistic_goal: bool = False,
    verbose: bool = True
) -> Tuple[Iterable[Any], Dict[str, Any]]:
    """Compositional Regression and Optimization Wayfinder.

    Args:
        executor: the executor.
        state: the initial state.
        goal_expr: the goal expression. A string is parsed with ``executor.parse``.
        is_goal_serialized: whether the goal is serialized already. Otherwise, it will be treated as a conjunction.
        enable_reordering: whether to enable reordering of subgoals in regression rules. Note that a non-zero
            reordering prefix currently raises ``NotImplementedError``.
        enable_csp: whether to maintain a constraint satisfaction problem for optimistic values.
            ``FindExpression`` items in regression rules require this to be True.
        max_actions: the maximum depth of nested subgoal regressions (``nr_high_level_actions``); search nodes
            deeper than this are not expanded.
        max_csp_branching_factor: currently unused by this implementation.
        max_beam_size: currently unused by this implementation.
        allow_empty_plan_for_optimistic_goal: whether to test the goal directly in the current state (i.e., accept
            the plan found so far) even when the goal contains optimistic constants. Goals without optimistic
            constants are always tested directly.
        verbose: whether to print verbose information.

    Returns:
        A tuple ``(candidate_plans, search_stat)``. ``candidate_plans`` is a list of plans, each plan being a list
        of actions. ``search_stat`` is a dictionary of search statistics (currently ``nr_expanded_nodes``).
    """
    if isinstance(goal_expr, str):
        goal_expr = executor.parse(goal_expr)

    search_cache = dict()
    search_stat = {'nr_expanded_nodes': 0}

    # NB(Jiayuan Mao @ 2023/09/09): the cache only works for previous_actions == [].
    # That is, we only cache the search results that start from the initial state.
    def return_with_cache(goal_set, previous_actions, rv):
        """Store `rv` for `goal_set` (only when no actions were taken yet), then return `rv`."""
        if len(previous_actions) == 0:
            goal_str = goal_set.gen_string()
            if goal_str not in search_cache:
                search_cache[goal_str] = rv
        return rv

    def try_retrieve_cache(goal_set, previous_actions):
        """Return the cached result for `goal_set` when no actions were taken yet; otherwise None."""
        if len(previous_actions) == 0:
            goal_str = goal_set.gen_string()
            if goal_str in search_cache:
                return search_cache[goal_str]
        return None

    @jacinle.log_function(verbose=False)
    def dfs(
        s: State, g: PartiallyOrderedPlan, c: Tuple[ValueOutputExpression, ...],
        csp: Optional[ConstraintSatisfactionProblem], previous_actions: List[OperatorApplier],
        return_all: bool = False,
        tail_csp_solve: bool = False,
        nr_high_level_actions: int = 0
    ) -> Iterator[Tuple[State, ConstraintSatisfactionProblem, List[OperatorApplier]]]:
        """Depth-first search for all possible plans.

        Args:
            s: the current state.
            g: the current goal.
            c: the tuple of constraints to maintain.
            csp: the current constraint satisfaction problem.
            previous_actions: the previous actions.
            return_all: whether to return all possible plans. If False, only return the first plan found.
            tail_csp_solve: whether to solve the CSP after the expansion of the current goal.
            nr_high_level_actions: the number of high-level actions.

        Returns:
            a list of plans. Each plan is a tuple of (final_state, csp, actions).
        """
        if verbose:
            jacinle.log_function.print('Current goal', g, f'return_all={return_all}', f'previous_actions={previous_actions}')

        if enable_csp and not tail_csp_solve:
            return_all = True  # Ignore the STRONG and ORDER serializability flags.

        if (rv := try_retrieve_cache(g, previous_actions)) is not None:
            return rv

        all_possible_plans = list()
        flatten_goals = list(g.iter_goals())
        # A goal without optimistic constants can always be tested directly in the current state. A goal with
        # optimistic constants is only tested here if `allow_empty_plan_for_optimistic_goal` is set.
        if not _has_optimistic_constant_expression(*flatten_goals) or allow_empty_plan_for_optimistic_goal:
            rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, flatten_goals, s, dict(), csp, csp_note='goal_test')
            if rv:
                all_possible_plans.append((s, new_csp, previous_actions))
                if not is_optimistic:  # If there is no optimistic value, we can stop the search from here.
                    # NB(Jiayuan Mao @ 2023/09/11): note that even if `return_all` is True, we still return here.
                    # This corresponds to an early stopping behavior that defines the space of all possible plans.
                    return return_with_cache(g, previous_actions, all_possible_plans)

        if nr_high_level_actions > max_actions:
            return return_with_cache(g, previous_actions, all_possible_plans)

        search_stat['nr_expanded_nodes'] += 1

        candidate_regression_rules = gen_applicable_regression_rules(executor, s, g, c)
        if _len_candidate_regression_rules(candidate_regression_rules) == 0:
            return return_with_cache(g, previous_actions, all_possible_plans)

        some_rule_success = False
        for chain_index, subgoal_index, this_candidate_regression_rules in candidate_regression_rules:
            other_goals = g.exclude(chain_index, subgoal_index)
            cur_goal = g.chains[chain_index].sequence[subgoal_index]
            if verbose:
                jacinle.log_function.print('Now trying to excluding goal', cur_goal)

            # First plan for all the other goals; `cur_goal` is then achieved on top of those plans.
            if len(other_goals) == 0:
                other_goals_plans = [(s, csp, previous_actions)]
            else:
                # TODO(Jiayuan Mao @ 2023/09/09): change this list to an actual generator call.
                need_return_all = any(rule.max_rule_prefix_length > 0 for rule, _ in this_candidate_regression_rules) or enable_csp
                other_goals_plans = list(dfs(s, other_goals, c, csp, previous_actions, nr_high_level_actions=nr_high_level_actions, return_all=need_return_all))

            for cur_state, cur_csp, cur_actions in other_goals_plans:
                rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, [cur_goal], cur_state, dict(), cur_csp, csp_note='goal_test_shortcut')
                if rv:
                    all_possible_plans.append((cur_state, new_csp, cur_actions))
                    if not is_optimistic:
                        # NB(Jiayuan Mao @ 2023/09/11): another place where we stop the search and ignores the `return_all` flag.
                        # NOTE(review): this `continue` is the last statement of the loop body, so it skips nothing.
                        continue

            if len(this_candidate_regression_rules) == 0:
                continue
            if len(other_goals_plans) == 0:
                continue

            if len(other_goals) == 0 or not enable_reordering:
                max_prefix_length = 0
            else:
                # TODO(Jiayuan Mao @ 2023/11/27): set this number to a very large number, and then use a flag to control the applicability of the reorderings.
                max_prefix_length = max(rule.max_reorder_prefix_length for rule, _ in this_candidate_regression_rules)

            prefix_stop_mark = dict()
            for prefix_length in range(max_prefix_length + 1):
                for regression_rule_index, (rule, bounded_variables) in enumerate(this_candidate_regression_rules):
                    if prefix_length > rule.max_reorder_prefix_length:
                        continue
                    if regression_rule_index in prefix_stop_mark and prefix_stop_mark[regression_rule_index]:
                        continue

                    if verbose:
                        jacinle.log_function.print('Applying rule', rule, 'for', cur_goal, 'and prefix length', prefix_length, 'goal is', g)

                    if prefix_length == 0:
                        # TODO(Jiayuan Mao @ 2023/11/19): there is a bug for max_rule_prefix_length when there is a list expansion.
                        if rule.max_rule_prefix_length > 0:
                            previous_possible_branches = other_goals_plans
                        else:
                            previous_possible_branches = [other_goals_plans[0]]
                    else:
                        raise NotImplementedError('Reordering is not implemented yet.')
                        # TODO(Jiayuan Mao @ 2023/11/24): implement this.
                        # cur_other_goals = other_goals.add_chain(grounded_subgoals[:prefix_length])
                        # previous_possible_branches = list(dfs(s, cur_other_goals, c, cur_csp, previous_actions, nr_high_level_actions=nr_high_level_actions + 1, return_all=rule.max_rule_prefix_length > 0))

                    if len(previous_possible_branches) == 0:
                        if verbose:
                            jacinle.log_function.print('Prefix planning failed!!! Stop.')
                        # If it's not possible to achieve the subset of goals, then it's not possible to achieve the whole goal.
                        # Therefore, this rule is marked as stopped so that it is skipped for all remaining prefix lengths.
                        prefix_stop_mark[regression_rule_index] = True
                        continue

                    for prev_state, prev_csp, prev_actions in previous_possible_branches:
                        # Construct the placeholder csp and the sequence of grounded subgoals.
                        grounded_subgoals = list()
                        placeholder_csp = ConstraintSatisfactionProblem()
                        placeholder_bounded_variables = bounded_variables.copy()
                        for item in rule.body:
                            if isinstance(item, AchieveExpression):
                                grounded_subgoals.append(AchieveExpression(ground_fol_expression(item.goal, placeholder_bounded_variables), maintains=[]))
                            elif isinstance(item, FindExpression):
                                for variable in item.variables:
                                    placeholder_bounded_variables[variable] = _create_find_expression_variable(variable, csp=placeholder_csp, bounded_variables=placeholder_bounded_variables)
                                grounded_subgoals.append(FindExpression([], ground_fol_expression(item.goal, placeholder_bounded_variables)))
                            elif isinstance(item, OperatorApplicationExpression):
                                cur_action = ground_operator_application_expression(item, placeholder_bounded_variables, csp=placeholder_csp, add_csp_variables=False)
                                grounded_subgoals.append(cur_action)
                            elif isinstance(item, ListExpansionExpression):
                                subgoals = executor.execute(item.expression, s, placeholder_bounded_variables, sgc=PDSketchSGC(s, g, c))
                                grounded_subgoals.extend(subgoals.sequence)
                            elif isinstance(item, CSPCommitFlag):
                                grounded_subgoals.append(item)
                            else:
                                raise ValueError(f'Unknown item type {type(item)} in rule {item}.')

                        # Each branch is (state, csp, actions, placeholder-identifier -> csp-variable mapping).
                        possible_branches = [(prev_state, prev_csp, prev_actions, {})]
                        for i in range(prefix_length, len(grounded_subgoals)):
                            item = grounded_subgoals[i]
                            next_possible_branches = list()
                            if isinstance(item, AchieveExpression):
                                if not enable_csp and item.serializability is SubgoalSerializability.STRONG and len(possible_branches) > 1:
                                    possible_branches = [min(possible_branches, key=lambda x: len(x[2]))]
                                need_return_all = enable_csp or i < rule.max_rule_prefix_length

                            for cur_state, cur_csp, cur_actions, cur_csp_variable_mapping in possible_branches:
                                if isinstance(item, AchieveExpression):
                                    new_csp = cur_csp.clone() if cur_csp is not None else None
                                    subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    # `tuple(...)` on both sides so that list- or tuple-valued maintains concatenate safely.
                                    next_possible_branches.extend([(*x, new_csp_variable_mapping) for x in dfs(
                                        cur_state, PartiallyOrderedPlan.from_single_goal(subgoal), tuple(c) + tuple(item.maintains),
                                        new_csp, cur_actions,
                                        return_all=need_return_all, nr_high_level_actions=nr_high_level_actions + 1
                                    )])
                                elif isinstance(item, FindExpression):
                                    if cur_csp is None:
                                        raise RuntimeError('FindExpression must be used with a CSP.')
                                    new_csp = cur_csp.clone()
                                    subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    with new_csp.with_group(subgoal) as group:
                                        rv = executor.execute(subgoal, cur_state, {}, csp=new_csp).item()
                                        if isinstance(rv, OptimisticValue):
                                            new_csp.add_equal_constraint(rv)
                                            # Generator contexts are evaluated in the state this branch has reached.
                                            _mark_solver(executor, cur_state, bounded_variables, group)
                                    # NOTE(review): a non-optimistic `rv` is not used to prune this branch; confirm this is intended.
                                    next_possible_branches.append((cur_state, new_csp, cur_actions, new_csp_variable_mapping))
                                elif isinstance(item, OperatorApplier):
                                    # TODO(Jiayuan Mao @ 2023/09/11): vectorize this operation, probably only useful when `return_all` is True.
                                    new_csp = cur_csp.clone() if cur_csp is not None else None
                                    subaction, new_csp_variable_mapping = _map_csp_placeholder_action(item, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    succ, new_state = executor.apply(subaction, cur_state, csp=new_csp, clone=True, action_index=len(cur_actions))
                                    if succ:
                                        next_possible_branches.append((new_state, new_csp, cur_actions + [subaction], new_csp_variable_mapping))
                                elif isinstance(item, CSPCommitFlag):
                                    # Solve the CSP accumulated so far and substitute the solution everywhere; the branch
                                    # then continues with a fresh, empty CSP.
                                    assignments = csp_dpll_sampling_solve(executor, cur_csp)
                                    if assignments is not None:
                                        new_state = _map_csp_variable_state(cur_state, cur_csp, assignments)
                                        new_csp = ConstraintSatisfactionProblem()
                                        new_actions = ground_actions(executor, cur_actions, assignments)
                                        # The assignments were solved from `cur_csp`, so ground the mapping against it.
                                        new_csp_variable_mapping = _map_csp_variable_mapping(cur_csp_variable_mapping, cur_csp, assignments)
                                        next_possible_branches.append((new_state, new_csp, new_actions, new_csp_variable_mapping))
                                        # TODO(Jiayuan Mao @ 2023/11/27): okay we need to implement some kind of tracking of "bounded_variables."
                                        # This need to be done by tracking some kind of mapping for optimistic variables in "grounded_subgoals."
                                else:
                                    raise TypeError(f'Unknown item: {item}')
                            possible_branches = next_possible_branches

                        found_plan = False
                        # TODO(Jiayuan Mao @ 2023/09/11): implement this via maintains checking.
                        for cur_state, cur_csp, actions, _ in possible_branches:
                            rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, flatten_goals, cur_state, dict(), csp=cur_csp, csp_note=f'subgoal_test: {"; ".join([str(x) for x in flatten_goals])}')
                            if rv:
                                if verbose:
                                    jacinle.log_function.print('Found a plan', [str(x) for x in actions], 'for goal', g)
                                if is_optimistic and tail_csp_solve:
                                    assignments = csp_dpll_sampling_solve(executor, new_csp, verbose=True)
                                    if assignments is not None:
                                        # Keep the (state, csp, actions) layout used by every other plan tuple.
                                        all_possible_plans.append((cur_state, new_csp, ground_actions(executor, actions, assignments)))
                                        found_plan = True
                                else:
                                    all_possible_plans.append((cur_state, new_csp, actions))
                                    found_plan = True

                        if found_plan:
                            prefix_stop_mark[regression_rule_index] = True
                            some_rule_success = True
                            # TODO(Jiayuan Mao @ 2023/09/06): since we have changed the order of prefix_length for-loop and the regression rule for-loop.
                            # We need to use an additional dictionary to store whether we have found a plan for a particular regression rule.
                            # Right now this doesn't matter because we only use the first plan.

                        # Stop iterating over `previous_possible_branches`.
                        if not return_all and some_rule_success:
                            break
                    # Stop iterating over the regression rules.
                    if not return_all and some_rule_success:
                        break
                # Stop iterating over the prefix lengths.
                if not return_all and some_rule_success:
                    break
            # Stop iterating over the candidate (chain, subgoal) groups.
            if not return_all and some_rule_success:
                break

        if len(all_possible_plans) == 0:
            if verbose:
                jacinle.log_function.print('No possible plans for goal', g)
            return return_with_cache(g, previous_actions, [])

        # TODO(Jiayuan Mao @ 2023/11/19): add unique back.
        # unique_all_possible_plans = _unique_plans(all_possible_plans)
        unique_all_possible_plans = all_possible_plans
        if len(unique_all_possible_plans) != len(all_possible_plans):
            if verbose:
                jacinle.log_function.print('Warning: there are duplicate plans for goal', g, f'({len(unique_all_possible_plans)} unique plans vs {len(all_possible_plans)} total plans)')

        # Shorter plans (fewer actions) first.
        unique_all_possible_plans = sorted(unique_all_possible_plans, key=lambda x: len(x[2]))
        return return_with_cache(g, previous_actions, unique_all_possible_plans)

    if is_and_expr(goal_expr):
        if len(goal_expr.arguments) == 1 and goal_expr.arguments[0].return_type.is_list_type:
            goal_set = [goal_expr]
        else:
            goal_set = list(goal_expr.arguments)
    else:
        goal_set = [goal_expr]

    goal_set = PartiallyOrderedPlan((TotallyOrderedPlan(goal_set, is_ordered=is_goal_serialized),))
    candidate_plans = dfs(state, goal_set, tuple(), csp=ConstraintSatisfactionProblem() if enable_csp else None, previous_actions=list(), tail_csp_solve=True)
    candidate_plans = [actions for _, _, actions in candidate_plans]
    return candidate_plans, search_stat
def _create_find_expression_variable(variable: Variable, csp: ConstraintSatisfactionProblem, bounded_variables: Dict[Variable, Any]):
    """Create a fresh actionable CSP value for a variable introduced by a `FindExpression`.

    For a list-typed variable, the length is taken from the last `ListValue` among the values of
    `bounded_variables`, and one fresh actionable variable (of the element type) is created per element.

    Args:
        variable: the variable in the FindExpression.
        csp: the CSP in which the new actionable variables are created.
        bounded_variables: the already bounded variables.

    Returns:
        a `TensorValue` built from a new optimistic value, or a `ListValue` of such values for list-typed variables.

    Raises:
        ValueError: if the variable is list-typed but no `ListValue` is present in `bounded_variables`.
    """
    def fresh(dtype):
        return TensorValue.from_optimistic_value(csp.new_actionable_var(dtype, wrap=True))

    if not variable.dtype.is_list_type:
        return fresh(variable.dtype)

    lengths = [len(bound) for bound in bounded_variables.values() if isinstance(bound, ListValue)]
    if not lengths:
        raise ValueError(f'Cannot create a list variable {variable} without specifying the length.')
    return ListValue(variable.dtype, [fresh(variable.dtype.element_type) for _ in range(lengths[-1])])
def _create_find_expression_variable_placeholder(variable: Variable, bounded_variables: Dict[Variable, Any]):
    """Create placeholder value(s) for a variable inside a `FindExpression` without touching any CSP.

    Unlike `_create_find_expression_variable`, no CSP variables are allocated: the result is an
    `UnnamedPlaceholder` of the variable's dtype, or, for a list-typed variable, a `ListValue` whose elements are
    `UnnamedPlaceholder(variable.dtype)`. The list length comes from the last `ListValue` among the values of
    `bounded_variables`.

    Args:
        variable: the variable in the FindExpression.
        bounded_variables: the already bounded variables.

    Raises:
        ValueError: if the variable is list-typed but no `ListValue` is present in `bounded_variables`.
    """
    dtype = variable.dtype
    if not dtype.is_list_type:
        return UnnamedPlaceholder(dtype)

    length = None
    for bound in bounded_variables.values():
        if isinstance(bound, ListValue):
            length = len(bound)
    if length is None:
        raise ValueError(f'Cannot create a list variable {variable} without specifying the length.')
    return ListValue(dtype, [UnnamedPlaceholder(dtype) for _ in range(length)])
def _mark_solver(executor: PDSketchExecutor, state: State, bounded_variables: Dict[Variable, Any], group: GroupConstraint):
"""Mark the solver for the current state.
Args:
executor: the executor.
state: the current state.
group: the current group constraint.
"""
for generator in executor.domain.generators.values():
if (matching := surface_fol_downcast(generator.certifies, group.expression)) is not None:
matching_success = True
inputs = list()
outputs = list()
for var in generator.context:
value = executor.execute(var, state, bounded_variables, optimistic_execution=True)
if _has_optimistic_value_or_list(value):
matching_success = False
break
inputs.append(value)
for var in generator.generates:
this_matching_success = False
if isinstance(var, VariableExpression):
if var.name in matching and _is_single_optimistic_value_or_list(matching[var.name]):
this_matching_success = True
outputs.append(_cvt_single_optimistic_value_or_list(matching[var.name]))
if not this_matching_success:
matching_success = False
if matching_success:
group.candidate_generators.append((generator, inputs, outputs))
def _has_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool:
    """Return whether `x` (a TensorValue, or a possibly nested ListValue) contains any optimistic value.

    Raises:
        ValueError: if `x`, or a nested element, is neither a `ListValue` nor a `TensorValue`.
    """
    if isinstance(x, ListValue):
        return any(map(_has_optimistic_value_or_list, x.values))
    if isinstance(x, TensorValue):
        return x.has_optimistic_value()
    raise ValueError(f'Unknown value type {type(x)}')
def _is_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool:
    """Return whether `x` is a single optimistic value, or a (possibly nested) ListValue made only of such values.

    An empty `ListValue` counts as True.

    Raises:
        ValueError: if `x`, or a nested element, is neither a `ListValue` nor a `TensorValue`.
    """
    if isinstance(x, ListValue):
        for element in x.values:
            if not _is_single_optimistic_value_or_list(element):
                return False
        return True
    if isinstance(x, TensorValue):
        return x.is_single_optimistic_value()
    raise ValueError(f'Unknown value type {type(x)}')
def _cvt_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> Union[ListValue, OptimisticValue]:
    """Unwrap `x` into its optimistic value(s).

    A `TensorValue` is replaced by `x.single_elem()`; a `ListValue` is converted element-wise into a new `ListValue`
    with the same dtype.

    Raises:
        ValueError: if `x`, or a nested element, is neither a `ListValue` nor a `TensorValue`.
    """
    if isinstance(x, ListValue):
        converted = [_cvt_single_optimistic_value_or_list(element) for element in x.values]
        return ListValue(x.dtype, converted)
    if isinstance(x, TensorValue):
        return x.single_elem()
    raise ValueError(f'Unknown value type {type(x)}')
def _has_optimistic_constant_expression(*expressions: ValueOutputExpression):
"""Check if there is a ConstantExpression whose value is an optimistic constant. Useful when checking if the subgoal is fully "gronded." """
for expression in expressions:
for e in iter_exprs(expression):
if isinstance(e, ConstantExpression) and _has_optimistic_value_or_list(e.constant):
return True
return False
def _len_candidate_regression_rules(candidate_regression_rules: List[ApplicableRegressionRuleGroup]) -> int:
"""Compute the number of candidate regression rules."""
return sum(len(x.regression_rules) for x in candidate_regression_rules)
class _ReplaceCSPVariableVisitor(IdentityExpressionVisitor):
    """Expression visitor that rewrites optimistic constants into actionable variables of a target CSP.

    Each single optimistic value found inside a `ConstantExpression` is looked up by its identifier in
    `variable_mapping`. On a miss, a fresh actionable variable is created in `csp` and recorded in
    `variable_mapping` (which is mutated in place), so repeated occurrences of one identifier share one variable.
    """

    def __init__(self, csp: ConstraintSatisfactionProblem, previous_csp: ConstraintSatisfactionProblem, variable_mapping: Dict[int, Any]):
        self.csp = csp
        self.previous_csp = previous_csp  # Stored for reference; not read by this visitor.
        self.variable_mapping = variable_mapping

    def _replace_opt_value(self, value: Any):
        """Return `value` with its single optimistic values replaced through `variable_mapping` (recursing into lists)."""
        if isinstance(value, ListValue):
            return ListValue(value.dtype, [self._replace_opt_value(element) for element in value.values])
        if not isinstance(value, TensorValue):
            raise ValueError(f'Unknown value type {type(value)}')
        if not value.is_single_optimistic_value():
            return value

        identifier = value.single_elem().identifier
        if identifier not in self.variable_mapping:
            self.variable_mapping[identifier] = TensorValue.from_optimistic_value(self.csp.new_actionable_var(value.dtype, wrap=True))
        return self.variable_mapping[identifier]

    def visit_constant_expression(self, expr: ConstantExpression) -> ConstantExpression:
        return ConstantExpression(self._replace_opt_value(expr.constant))
# Example usage: subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
def _map_csp_placeholder_goal(
    subgoal: ValueOutputExpression, csp: ConstraintSatisfactionProblem,
    placeholder_csp: ConstraintSatisfactionProblem, placeholder_bounded_variables: Dict[Variable, Any],
    csp_variable_mapping: Dict[int, TensorValue]
) -> Tuple[ValueOutputExpression, Dict[int, TensorValue]]:
    """Rewrite the placeholder-CSP optimistic constants in `subgoal` into variables of `csp`.

    Args:
        subgoal: the grounded subgoal whose optimistic constants come from the placeholder CSP.
        csp: the CSP in which missing variables are created.
        placeholder_csp: the placeholder CSP (passed through to the visitor).
        placeholder_bounded_variables: not used by this function.
        csp_variable_mapping: the current placeholder-identifier -> CSP-variable mapping; it is not modified.

    Returns:
        the rewritten subgoal and an extended copy of `csp_variable_mapping`.
    """
    extended_mapping = csp_variable_mapping.copy()
    rewritten_subgoal = _ReplaceCSPVariableVisitor(csp, placeholder_csp, extended_mapping).visit(subgoal)
    return rewritten_subgoal, extended_mapping
def _map_csp_placeholder_action(
    action: OperatorApplier, csp: ConstraintSatisfactionProblem,
    placeholder_csp: ConstraintSatisfactionProblem, placeholder_bounded_variables: Dict[Variable, Any],
    csp_variable_mapping: Dict[int, TensorValue]
) -> Tuple[OperatorApplier, Dict[int, TensorValue]]:
    """Rewrite the placeholder-CSP optimistic arguments of `action` into variables of `csp`.

    Each argument that is a `TensorValue` holding a single optimistic value is replaced by the mapped CSP variable,
    creating a fresh actionable variable in `csp` on first use. All other arguments are passed through unchanged.

    Args:
        action: the grounded action whose arguments may reference placeholder-CSP variables.
        csp: the CSP in which missing variables are created.
        placeholder_csp: not used by this function.
        placeholder_bounded_variables: not used by this function.
        csp_variable_mapping: the current placeholder-identifier -> CSP-variable mapping; it is not modified.

    Returns:
        a new `OperatorApplier` with rewritten arguments and an extended copy of `csp_variable_mapping`.
    """
    new_mapping = csp_variable_mapping.copy()

    def remap(argument):
        if not isinstance(argument, TensorValue) or not argument.is_single_optimistic_value():
            return argument
        identifier = argument.single_elem().identifier
        if identifier not in new_mapping:
            new_mapping[identifier] = TensorValue.from_optimistic_value(csp.new_actionable_var(argument.dtype, wrap=True))
        return new_mapping[identifier]

    new_arguments = [remap(argument) for argument in action.arguments]
    return OperatorApplier(action.operator, new_arguments), new_mapping
def _map_csp_variable_mapping(
csp_variable_mapping: Dict[int, TensorValue], csp: ConstraintSatisfactionProblem, assignments: AssignmentDict
) -> Dict[int, TensorValue]:
"""Map the CSP variable mapping to the new variable mapping."""
new_mapping = dict()
for identifier, value in csp_variable_mapping.items():
if isinstance(value, TensorValue):
if value.is_single_optimistic_value():
new_identifier = value.single_elem().identifier
if new_identifier in assignments:
new_value = csp.ground_assignment_value_partial(assignments, new_identifier)
if isinstance(new_value, OptimisticValue):
new_mapping[identifier] = TensorValue.from_optimistic_value(new_value)
elif isinstance(new_value, TensorValue):
new_mapping[identifier] = new_value
else:
raise TypeError(f'Unknown value type {type(new_value)}')
else:
new_mapping[identifier] = value
else:
raise TypeError(f'Unknown value type {type(value)}')
return new_mapping
def _map_csp_variable_state(
    state: State, csp: ConstraintSatisfactionProblem, assignments: AssignmentDict
) -> State:
    """Return a clone of `state` with the CSP solution substituted into its optimistic feature entries.

    For every feature that has a `tensor_optimistic_values` tensor, each non-zero entry whose identifier appears in
    `assignments` is resolved with `csp.ground_assignment_value_partial`. A `TensorValue` result is written into the
    feature's `tensor` and its optimistic identifier is reset to 0. An `OptimisticValue` result replaces the
    identifier. Entries whose identifier is not in `assignments` are left unchanged.

    Args:
        state: the state to copy; it is not modified.
        csp: the CSP that `assignments` solves.
        assignments: the CSP solution.

    Raises:
        TypeError: if grounding yields a value that is neither an `OptimisticValue` nor a `TensorValue`.
    """
    grounded_state = state.clone()
    for _, feature in grounded_state.features.items():
        if feature.tensor_optimistic_values is None:
            continue

        optimistic_indices = [tuple(index) for index in torch.nonzero(feature.tensor_optimistic_values).tolist()]
        for index in optimistic_indices:
            identifier = feature.tensor_optimistic_values[index].item()
            if identifier not in assignments:
                continue

            new_value = csp.ground_assignment_value_partial(assignments, identifier)
            if isinstance(new_value, OptimisticValue):
                feature.tensor_optimistic_values[index] = new_value.identifier
            elif isinstance(new_value, TensorValue):
                feature.tensor[index] = new_value.tensor
                feature.tensor_optimistic_values[index] = 0
            else:
                raise TypeError(f'Unknown value type {type(new_value)}')
    return grounded_state