#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : atomic_strips_onthefly_search.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 07/14/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import time
from typing import Optional, Union, Iterator, Tuple, List, Dict
from collections import defaultdict, deque
from concepts.dsl.dsl_types import Variable
from concepts.dm.pdsketch.strips.strips_expression import SStateDict, SBoolPredicateApplicationExpression
from concepts.dm.pdsketch.strips.atomic_strips_domain import AtomicStripsDomain, AtomicStripsProblem, AtomicStripsOperator, AtomicStripsOperatorApplier
def _bind_arguments(predicate: SBoolPredicateApplicationExpression, bound_arguments: Dict[str, Union[int, str]]):
    """Resolve the arguments of a predicate application under a variable binding.

    Args:
        predicate: the predicate application whose arguments are resolved.
        bound_arguments: a mapping from variable name to the value bound to that variable.

    Returns:
        a pair ``(name, arguments)``: the predicate name and a tuple of its arguments, in which every
        :class:`Variable` argument has been replaced by ``bound_arguments[arg.name]`` and every other
        argument is passed through unchanged.
    """
    predicate_name = predicate.name
    resolved = list()
    for argument in predicate.arguments:
        if isinstance(argument, Variable):
            # Variables are looked up by name; a missing binding raises from the mapping lookup.
            resolved.append(bound_arguments[argument.name])
        else:
            resolved.append(argument)
    return predicate_name, tuple(resolved)
def _gen_applicable_actions(domain: AtomicStripsDomain, objects: Dict[str, List[str]], state: SStateDict, check_negation: bool = False) -> Iterator[Tuple[AtomicStripsOperator, Dict[str, str]]]:
    """Enumerate the (operator, binding) pairs whose preconditions hold in ``state``.

    For every operator in ``domain.operators`` a depth-first search over partial variable bindings is run:

    - a precondition with no unbound variable is tested directly against ``state``;
    - a precondition with exactly one unbound variable is used to enumerate the objects that satisfy it;
    - a precondition with more unbound variables is skipped until more variables are bound.

    When no precondition can make progress, the unbound operator argument with the fewest candidate
    objects is enumerated instead.

    Args:
        domain: the domain whose operators are enumerated.
        objects: a mapping from type name to the list of object names of that type.
        state: the state against which preconditions are tested.
        check_negation: forwarded as the ``check_negation`` keyword of ``state.contains``.

    Yields:
        pairs ``(operator, bound_arguments)``; each ``bound_arguments`` is a fresh dict mapping variable
        names to object names.
    """
    # Unique marker objects returned in place of a candidate list. They must be compared by identity.
    TOO_MANY, FAILED, PASS = object(), object(), object()

    def compute_possible_grounding(predicate: SBoolPredicateApplicationExpression, bound_arguments: Dict[str, Union[int, str]]):
        """Classify a precondition under the current partial binding.

        Returns:
            ``('', FAILED)`` or ``('', PASS)`` when the predicate is fully bound, ``('', TOO_MANY)`` when
            it has more than one unbound argument, and ``(variable name, satisfying objects)`` when it has
            exactly one. ``bound_arguments`` is left unchanged on return.
        """
        unbound_arguments = [arg for arg in predicate.arguments if isinstance(arg, Variable) and arg.name not in bound_arguments]
        if len(unbound_arguments) == 0:
            name, arguments = _bind_arguments(predicate, bound_arguments)
            if state.contains(name, arguments, predicate.negated, check_negation=check_negation):
                return '', PASS
            return '', FAILED
        if len(unbound_arguments) == 1:
            arg = unbound_arguments[0]
            valid_arguments = list()
            for candidate in objects[arg.typename]:
                # Temporarily bind the variable so the predicate can be grounded and tested.
                bound_arguments[arg.name] = candidate
                name, arguments = _bind_arguments(predicate, bound_arguments)
                if state.contains(name, arguments, predicate.negated, check_negation=check_negation):
                    valid_arguments.append(candidate)
                del bound_arguments[arg.name]
            return arg.name, valid_arguments
        return '', TOO_MANY

    def dfs(operator: AtomicStripsOperator, preconditions: Tuple[SBoolPredicateApplicationExpression, ...], bound_arguments: Dict[str, Union[int, str]]):
        """Inner DFS function.

        Args:
            operator: the operator being grounded (passed explicitly rather than captured from the
                enclosing loop variable).
            preconditions: the preconditions that remain to be satisfied.
            bound_arguments: a mapping from variable name to object; mutated during the search and
                restored before returning.

        Returns:
            a list of complete bindings (dict copies) that satisfy the preconditions.
        """
        for i, precondition in enumerate(preconditions):
            name, valid_arguments = compute_possible_grounding(precondition, bound_arguments)
            if valid_arguments is FAILED:
                return list()
            if valid_arguments is TOO_MANY:
                # Cannot make progress with this precondition yet; try the next one.
                continue

            remaining = preconditions[:i] + preconditions[i + 1:]
            if valid_arguments is PASS:
                return dfs(operator, remaining, bound_arguments)

            # Exactly one unbound variable: branch on every object that satisfies the precondition.
            outputs = list()
            for candidate in valid_arguments:
                bound_arguments[name] = candidate
                outputs.extend(dfs(operator, remaining, bound_arguments))
                del bound_arguments[name]
            return outputs

        # Every remaining precondition (if any) has two or more unbound variables.
        unbound_arguments = [arg for arg in operator.arguments if isinstance(arg, Variable) and arg.name not in bound_arguments]
        if len(unbound_arguments) == 0:
            return [bound_arguments.copy()]

        # Branch on the operator argument with the fewest candidate objects.
        unbound_arguments_possible_values = {arg.name: objects[arg.typename] for arg in unbound_arguments}
        name, valid_arguments = min(unbound_arguments_possible_values.items(), key=lambda x: len(x[1]))
        outputs = list()
        for candidate in valid_arguments:
            bound_arguments[name] = candidate
            outputs.extend(dfs(operator, preconditions, bound_arguments))
            del bound_arguments[name]
        return outputs

    for operator in domain.operators.values():
        for bound_arguments in dfs(operator, operator.preconditions, dict()):
            yield operator, bound_arguments
def _check_precondition(state: SStateDict, operator: AtomicStripsOperator, bound_arguments: Dict[str, Union[int, str]]):
    """Tell whether every precondition of ``operator`` holds in ``state`` under the given binding.

    Args:
        state: the state queried through ``state.contains``.
        operator: the operator whose ``preconditions`` are checked, in order.
        bound_arguments: a mapping from variable name to the value bound to it.

    Returns:
        True if ``state.contains`` is truthy for every grounded precondition, False as soon as one is not.
    """
    def holds(precondition):
        name, arguments = _bind_arguments(precondition, bound_arguments)
        return state.contains(name, arguments, precondition.negated)

    # all() stops at the first precondition that does not hold.
    return all(holds(precondition) for precondition in operator.preconditions)
def _apply_operator(state: SStateDict, operator: AtomicStripsOperator, bound_arguments: Dict[str, Union[int, str]]):
    """Compute the successor of ``state`` after applying ``operator`` under the given binding.

    The input state is not modified: the effects are applied to ``state.clone()``.

    Args:
        state: the state to start from.
        operator: the operator providing ``del_effects`` and ``add_effects``.
        bound_arguments: a mapping from variable name to the value bound to it.

    Returns:
        the cloned state with all delete effects removed first and all add effects added afterwards.
    """
    successor = state.clone()
    # Delete effects are processed before add effects, so an atom that is both deleted and added ends up added.
    passes = (
        (operator.del_effects, successor.remove),
        (operator.add_effects, successor.add),
    )
    for effects, update in passes:
        for effect in effects:
            name, arguments = _bind_arguments(effect, bound_arguments)
            update(name, arguments)
    return successor
def _ground_actions(actions: Tuple[Tuple[AtomicStripsOperator, Dict[str, str]], ...]) -> Tuple[AtomicStripsOperatorApplier, ...]:
    """Turn a sequence of (operator, binding) pairs into a tuple of grounded operators.

    Args:
        actions: pairs ``(operator, bound_arguments)``.

    Returns:
        a tuple holding ``operator.ground(bound_arguments)`` for every pair, in the same order.
    """
    return tuple(operator.ground(bound_arguments) for operator, bound_arguments in actions)
# --- Search entry point ---
def astrips_onthefly_search(domain: AtomicStripsDomain, problem: AtomicStripsProblem, verbose: bool = False, timeout: float = 300, max_expanded_nodes: int = 1000000) -> Optional[Tuple[AtomicStripsOperatorApplier, ...]]:
    """Breadth-first forward search that generates applicable actions on the fly.

    States are expanded in FIFO order. Successors are produced by ``_gen_applicable_actions`` and
    ``_apply_operator``, and a successor is only enqueued the first time its ``as_state()`` value is seen.

    Args:
        domain: the domain providing ``constants`` and ``operators``.
        problem: the problem providing ``objects``, ``initial_state`` (whitespace-separated atom strings)
            and ``conjunctive_goal``.
        verbose: accepted for interface compatibility; not read by this function.
        timeout: wall-clock budget in seconds, checked once every 100 expanded nodes.
        max_expanded_nodes: maximum number of nodes taken from the frontier.

    Returns:
        a tuple of grounded operators reaching a state that is a superset of the goal (the empty tuple
        when the initial state already satisfies the goal), or None when the frontier is exhausted or a
        budget is exceeded.
    """
    objects = defaultdict(list)
    # Dense 0-based index per object name. It is built here but not read by the search below.
    object2index = dict()
    for name, constant in domain.constants.items():
        objects[constant.dtype.typename].append(name)
        object2index.setdefault(name, len(object2index))
    for name, typename in problem.objects.items():
        objects[typename].append(name)
        object2index.setdefault(name, len(object2index))

    initial_state = SStateDict()
    for predicate in problem.initial_state:
        name, *args = predicate.split()
        # Use a tuple so the argument container matches what _bind_arguments produces elsewhere.
        initial_state.add(name, tuple(args))

    goal_conditions = problem.conjunctive_goal

    initial_state_set = initial_state.as_state()
    if initial_state_set.issuperset(goal_conditions):
        # Nothing to do: the empty plan already achieves the goal.
        return tuple()

    frontier = deque([(initial_state, tuple())])
    # Seed with the initial state so that it is never re-enqueued.
    explored = {initial_state_set}

    start_time = time.time()
    nr_expanded_nodes = 0
    while len(frontier) > 0:
        nr_expanded_nodes += 1
        if nr_expanded_nodes > max_expanded_nodes:
            break
        # Reading the clock on every node is unnecessary; sample it every 100 expansions.
        if nr_expanded_nodes % 100 == 0:
            if time.time() - start_time > timeout:
                print('astrips_onthefly_search::Timeout.')
                break

        state, plan = frontier.popleft()
        for operator, bound_arguments in _gen_applicable_actions(domain, objects, state):
            new_state = _apply_operator(state, operator, bound_arguments)
            new_state_set = new_state.as_state()
            if new_state_set in explored:
                continue

            new_plan = plan + ((operator, bound_arguments), )
            if new_state_set.issuperset(goal_conditions):
                return _ground_actions(new_plan)
            frontier.append((new_state, new_plan))
            explored.add(new_state_set)

    return None