Source code for concepts.pdsketch.crow.crow_state

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_state.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 11/09/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

from typing import Optional, Union, Iterator, Sequence, Tuple, List

import jacinle
from concepts.dsl.expression import ValueOutputExpression
from concepts.pdsketch.regression_rule import AchieveExpression, BindExpression, RegressionRuleApplier

# Public names exported by this module: the three classes defined below.
__all__ = ['TotallyOrderedPlan', 'PartiallyOrderedPlan', 'RegressionNode']


class TotallyOrderedPlan(object):
    """A totally ordered plan sequence.

    The plan stores a tuple of items (expressions or regression rule appliers), an
    optional per-item ``return_all_skeletons`` flag, and whether the items are ordered.
    Instances are not mutated by the methods of this class; :meth:`exclude` returns a new plan.
    """

    def __init__(
        self,
        sequence: Sequence[Union[ValueOutputExpression, RegressionRuleApplier]],
        return_all_skeletons_flags: Optional[Union[Sequence[bool], bool]] = None,
        is_ordered: bool = True
    ):
        """Initialize the totally ordered plan sequence.

        Args:
            sequence: the sequence of the plan.
            return_all_skeletons_flags: for each item in the sequence, whether to return all the skeletons.
                A single bool is broadcast to every item of the sequence. None means "not set".
            is_ordered: whether the sequence is ordered.
        """
        self.sequence = tuple(sequence)

        if isinstance(return_all_skeletons_flags, bool):
            # Broadcast a single flag to every item. Use the stored tuple for the length so that
            # non-sized iterables passed as `sequence` are handled as well.
            return_all_skeletons_flags = (return_all_skeletons_flags,) * len(self.sequence)

        # Always assign the attribute (previously it was left unset when a bool was given, which
        # made `exclude` and `get_return_all_skeletons_flag` fail with AttributeError).
        if return_all_skeletons_flags is None:
            self.return_all_skeletons_flags = None
        else:
            self.return_all_skeletons_flags = tuple(return_all_skeletons_flags)

        self.is_ordered = is_ordered

    sequence: Tuple[Union[ValueOutputExpression, RegressionRuleApplier], ...]
    """The sequence of the plan."""

    return_all_skeletons_flags: Optional[Tuple[bool, ...]]
    """For each item in the sequence, whether to return all the skeletons. This value can be set to None when is_ordered is False."""

    is_ordered: bool
    """Whether the sequence is ordered. If it's not ordered, then the sequence is treated as a set."""

    def exclude(self, index: int):
        """Exclude the given index from the sequence; return a new sequence.

        Args:
            index: the position of the item to drop.

        Returns:
            a new :class:`TotallyOrderedPlan` without the item at ``index``. The flags (when set)
            are sliced the same way, and ``is_ordered`` is carried over.
        """
        remaining_flags = None
        if self.return_all_skeletons_flags is not None:
            remaining_flags = self.return_all_skeletons_flags[:index] + self.return_all_skeletons_flags[index + 1:]
        return TotallyOrderedPlan(
            self.sequence[:index] + self.sequence[index + 1:],
            remaining_flags,
            self.is_ordered
        )

    def get_return_all_skeletons_flag(self, index: int) -> bool:
        """Get the return_all_skeletons flag of the given index. If the flag is not set, then return True."""
        if self.return_all_skeletons_flags is None:
            return True
        return self.return_all_skeletons_flags[index]

    def gen_string(self):
        """Generate the string representation of the plan.

        Ordered plans are rendered as ``(then item1 item2 ...)`` and unordered ones as
        ``(and item1 item2 ...)``, using ``str()`` of each item.
        """
        keyword = 'then' if self.is_ordered else 'and'
        return '(' + keyword + ' ' + ' '.join([str(e) for e in self.sequence]) + ')'

    def __len__(self):
        """Return the number of items in the sequence."""
        return len(self.sequence)

    def __str__(self):
        return self.gen_string()

    __repr__ = jacinle.repr_from_str
class PartiallyOrderedPlan(object):
    """A partially ordered plan is a set of totally ordered plan sequences."""

    def __init__(self, chains: Sequence[TotallyOrderedPlan]):
        """Initialize the partially ordered plan.

        Args:
            chains: a collection of the totally ordered plan sequences.
        """
        self.chains = tuple(chains)
        # No chain is marked as infeasible until `set_infeasible_index` is called.
        self.infeasible_chain_index = None

    chains: Tuple[TotallyOrderedPlan, ...]
    """The totally ordered plan sequences."""

    infeasible_chain_index: Optional[int]
    """One can optionally mark one of the chain as infeasible to be the last chain. Currently this flag is not used."""

    def exclude(self, chain_index: int, item_index: int) -> 'PartiallyOrderedPlan':
        """Exclude the given item from the given chain; return a new plan.

        If the chain has a single item, the whole chain is removed. The returned plan is built
        through the constructor, so it has no infeasible chain marked.

        Args:
            chain_index: the index of the chain in :attr:`chains`.
            item_index: the index of the item inside that chain.
        """
        before = self.chains[:chain_index]
        after = self.chains[chain_index + 1:]
        target = self.chains[chain_index]

        if len(target) == 1:
            assert item_index == 0
            return PartiallyOrderedPlan(before + after)
        return PartiallyOrderedPlan(before + (target.exclude(item_index),) + after)

    def add_chain(
        self,
        chain: Sequence[Union[ValueOutputExpression, RegressionRuleApplier]],
        return_all_skeletons_flags: Optional[Union[Sequence[bool], bool]] = False
    ) -> 'PartiallyOrderedPlan':
        """Add a new chain to the plan; return a new plan.

        Args:
            chain: the items of the new chain.
            return_all_skeletons_flags: per-item flags for the new chain. A single bool (including
                the default False) is applied to every item of the chain; None leaves the flags unset.
        """
        chain = tuple(chain)
        if isinstance(return_all_skeletons_flags, bool):
            # Expand the single flag here so that the new chain always receives one flag per item.
            return_all_skeletons_flags = [return_all_skeletons_flags] * len(chain)
        return PartiallyOrderedPlan(self.chains + (TotallyOrderedPlan(chain, return_all_skeletons_flags),))

    @classmethod
    def from_single_goal(
        cls,
        goal: Union[ValueOutputExpression, RegressionRuleApplier],
        return_all_skeletons_flag: bool = False
    ) -> 'PartiallyOrderedPlan':
        """Create a plan from a single goal: one ordered chain that contains only ``goal``."""
        single_chain = TotallyOrderedPlan((goal,), [return_all_skeletons_flag], is_ordered=True)
        return cls((single_chain,))

    @property
    def nr_chains(self):
        """The number of chains in the plan."""
        return len(self.chains)

    @property
    def total_length(self):
        """The total number of items, summed over all chains."""
        return sum(len(chain) for chain in self.chains)

    def set_infeasible_index(self, chain_index: int):
        """Mark the chain at ``chain_index`` as infeasible (it is skipped by :meth:`iter_feasible_chains`)."""
        self.infeasible_chain_index = chain_index

    def iter_feasible_chains(self) -> Iterator[Tuple[int, TotallyOrderedPlan]]:
        """Iterate over ``(index, chain)`` pairs, skipping the chain marked as infeasible (if any)."""
        for i, seq in enumerate(self.chains):
            if i == self.infeasible_chain_index:
                continue
            yield i, seq

    def iter_goals(self) -> Iterator[ValueOutputExpression]:
        """Iterate over the goals of all chains (including the one marked as infeasible).

        Items that are :class:`ValueOutputExpression` are yielded as-is; for any other item its
        ``goal_expression`` attribute is yielded.
        """
        for chain in self.chains:
            for goal in chain.sequence:
                yield goal if isinstance(goal, ValueOutputExpression) else goal.goal_expression

    def gen_string(self):
        """Generate the string representation: ``(and chain1 chain2 ...)``."""
        return '(and ' + ' '.join([chain.gen_string() for chain in self.chains]) + ')'

    def __len__(self):
        """Return the total number of items over all chains."""
        return self.total_length

    def __str__(self):
        return self.gen_string()

    __repr__ = jacinle.repr_from_str
class RegressionNode(object):
    """A tree node holding a goal expression, an optional regression rule, and a list of child nodes."""

    goal_expression: Union[BindExpression, AchieveExpression]
    """The goal expression of this node."""

    children: List['RegressionNode']
    """The children of this node."""

    def __init__(
        self,
        goal_expression: Union[BindExpression, AchieveExpression],
        associated_regression_rule: Optional[RegressionRuleApplier] = None
    ):
        """Create a node without children.

        Args:
            goal_expression: the goal expression stored on this node.
            associated_regression_rule: the regression rule applier stored on this node, if any.
        """
        # Every node gets its own fresh list of children.
        self.children = list()
        self.associated_regression_rule = associated_regression_rule
        self.goal_expression = goal_expression

    def add_child(self, node: 'RegressionNode'):
        """Append ``node`` to the end of :attr:`children`."""
        kids = self.children
        kids.append(node)