# Source code for concepts.pdsketch.strips.atomic_strips_onthefly_search

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : atomic_strips_onthefly_search.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/14/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import time
from typing import Optional, Union, Iterator, Tuple, List, Dict
from collections import defaultdict, deque

from concepts.dsl.dsl_types import Variable
from concepts.pdsketch.strips.strips_expression import SStateDict, SBoolPredicateApplicationExpression
from concepts.pdsketch.strips.atomic_strips_domain import AtomicStripsDomain, AtomicStripsProblem, AtomicStripsOperator, AtomicStripsOperatorApplier


def _bind_arguments(predicate: SBoolPredicateApplicationExpression, bound_arguments: Dict[str, Union[int, str]]):
    """Ground a predicate application under a variable binding.

    Args:
        predicate: the predicate application whose arguments are grounded.
        bound_arguments: a mapping from variable name to the object bound to it.

    Returns:
        A ``(name, arguments)`` pair. ``arguments`` is a tuple in which every
        :class:`Variable` argument is replaced by ``bound_arguments[var.name]``;
        non-Variable arguments are passed through unchanged.

    Raises:
        KeyError: if some Variable argument has no entry in ``bound_arguments``.
    """
    predicate_name = predicate.name
    grounded = []
    for argument in predicate.arguments:
        if isinstance(argument, Variable):
            grounded.append(bound_arguments[argument.name])
        else:
            grounded.append(argument)
    return predicate_name, tuple(grounded)


def _gen_applicable_actions(domain: AtomicStripsDomain, objects: Dict[str, List[str]], state: SStateDict, check_negation: bool = False) -> Iterator[Tuple[AtomicStripsOperator, Dict[str, str]]]:
    """Enumerate every grounding of every operator whose preconditions hold in ``state``.

    For each operator, a DFS over variable bindings is run. At each level it first looks for a
    precondition that can be decided outright (no unbound variables) or that constrains exactly
    one unbound variable, and branches only on the values that satisfy it. Only when every
    remaining precondition has two or more unbound variables does it fall back to enumerating
    the typed domain of an unbound operator argument, picking the one with the fewest candidates.

    Args:
        domain: the domain whose ``operators`` are grounded.
        objects: a mapping from type name to the list of objects of that type.
        state: the state in which preconditions are checked.
        check_negation: forwarded to ``state.contains`` for every precondition check.

    Yields:
        ``(operator, bound_arguments)`` pairs. Each ``bound_arguments`` is a fresh dict
        mapping variable names to objects.
    """
    # Private sentinels returned by ``compute_possible_grounding``; compared by identity.
    TOO_MANY, FAILED, PASS = object(), object(), object()

    def compute_possible_grounding(predicate: SBoolPredicateApplicationExpression, bound_arguments: Dict[str, Union[int, str]]):
        """Classify ``predicate`` under the current partial binding.

        Returns:
            ``('', PASS)`` / ``('', FAILED)`` if the predicate is fully bound and does / does not
            hold; ``(var_name, values)`` if exactly one variable is unbound, where ``values`` are
            the objects for that variable that make the predicate hold; ``('', TOO_MANY)`` if two
            or more variables are unbound.
        """
        unbound_arguments = [arg for arg in predicate.arguments if isinstance(arg, Variable) and arg.name not in bound_arguments]
        if len(unbound_arguments) == 0:
            name, arguments = _bind_arguments(predicate, bound_arguments)
            if state.contains(name, arguments, predicate.negated, check_negation=check_negation):
                return '', PASS
            return '', FAILED
        if len(unbound_arguments) == 1:
            arg = unbound_arguments[0]
            valid_arguments = list()
            for candidate in objects[arg.typename]:
                # Temporarily bind the variable to test this candidate, then restore.
                bound_arguments[arg.name] = candidate
                name, arguments = _bind_arguments(predicate, bound_arguments)
                if state.contains(name, arguments, predicate.negated, check_negation=check_negation):
                    valid_arguments.append(candidate)
                del bound_arguments[arg.name]
            return arg.name, valid_arguments
        return '', TOO_MANY

    def dfs(operator: AtomicStripsOperator, preconditions: Tuple[SBoolPredicateApplicationExpression, ...], bound_arguments: Dict[str, Union[int, str]]):
        """Return all complete bindings for ``operator`` that extend ``bound_arguments``.

        ``bound_arguments`` is mutated during the search but restored before returning.

        Args:
            operator: the operator being grounded (passed explicitly rather than captured
                from the caller's loop variable).
            preconditions: the preconditions not yet checked.
            bound_arguments: a mapping from variable name to object.
        """
        for i, precondition in enumerate(preconditions):
            name, valid_arguments = compute_possible_grounding(precondition, bound_arguments)
            if valid_arguments is FAILED:
                return list()
            if valid_arguments is TOO_MANY:
                # Cannot prune cheaply with this one; try the next precondition.
                continue

            remaining = preconditions[:i] + preconditions[i + 1:]
            if valid_arguments is PASS:
                return dfs(operator, remaining, bound_arguments)

            outputs = list()
            for value in valid_arguments:
                bound_arguments[name] = value
                outputs.extend(dfs(operator, remaining, bound_arguments))
                del bound_arguments[name]
            return outputs

        unbound_arguments = [arg for arg in operator.arguments if isinstance(arg, Variable) and arg.name not in bound_arguments]
        if len(unbound_arguments) == 0:
            return [bound_arguments.copy()]

        # Every remaining precondition has >= 2 unbound variables: brute-force the operator
        # argument with the smallest typed domain to keep the branching factor low.
        unbound_arguments_possible_values = {arg.name: objects[arg.typename] for arg in unbound_arguments}
        name, candidates = min(unbound_arguments_possible_values.items(), key=lambda x: len(x[1]))

        outputs = list()
        for value in candidates:
            bound_arguments[name] = value
            outputs.extend(dfs(operator, preconditions, bound_arguments))
            # Unbind inside the loop: if ``candidates`` is empty nothing was bound, and a
            # trailing ``del`` would raise KeyError.
            del bound_arguments[name]
        return outputs

    for operator in domain.operators.values():
        for bound_arguments in dfs(operator, operator.preconditions, dict()):
            yield operator, bound_arguments


def _check_precondition(state: SStateDict, operator: AtomicStripsOperator, bound_arguments: Dict[str, Union[int, str]]):
    """Test whether all of ``operator``'s preconditions hold in ``state``.

    Each precondition is grounded with ``bound_arguments`` and checked via
    ``state.contains(name, arguments, negated)``. Checking stops at the first failing
    precondition. ``check_negation`` is not passed, so ``state.contains`` uses its own default.

    Returns:
        ``True`` if every grounded precondition holds, otherwise ``False``.
    """
    def _holds(precondition) -> bool:
        grounded_name, grounded_args = _bind_arguments(precondition, bound_arguments)
        return state.contains(grounded_name, grounded_args, precondition.negated)

    return all(_holds(precondition) for precondition in operator.preconditions)


def _apply_operator(state: SStateDict, operator: AtomicStripsOperator, bound_arguments: Dict[str, Union[int, str]]):
    """Compute the successor of ``state`` under a grounded operator.

    Works on ``state.clone()``. First every delete effect is removed, then every add effect is
    added (both grounded with ``bound_arguments``), so an atom that appears in both lists is
    handled by the add last.

    Returns:
        The modified clone.
    """
    successor = state.clone()
    for effect in operator.del_effects:
        successor.remove(*_bind_arguments(effect, bound_arguments))
    for effect in operator.add_effects:
        successor.add(*_bind_arguments(effect, bound_arguments))
    return successor


def _ground_actions(actions: Tuple[Tuple[AtomicStripsOperator, Dict[str, str]], ...]) -> Tuple[AtomicStripsOperatorApplier, ...]:
    """Turn ``(operator, bound_arguments)`` pairs into ground operator appliers.

    Each pair is converted with ``operator.ground(bound_arguments)``; input order is preserved.
    """
    return tuple(op.ground(binding) for op, binding in actions)