Source code for concepts.pdsketch.strips.atomic_strips_regression_search

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : atomic_strips_regression_search.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/13/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import itertools
from typing import List, Set

import jacinle

from concepts.pdsketch.strips.strips_expression import SProposition
from concepts.pdsketch.strips.atomic_strips_domain import (
    AtomicStripsDomain, AtomicStripsProblem,
    AtomicStripsOperator, AtomicStripsRegressionRule,
    AtomicStripsOperatorApplier, AtomicStripsRegressionRuleApplier,
    AtomicStripsGroundedAchieveExpression
)


def gen_all_grounded_actions_and_rules(domain: AtomicStripsDomain, problem: AtomicStripsProblem, mode: str) -> List[AtomicStripsOperatorApplier]:
    """Enumerate the groundings of every operator or regression rule in the domain.

    Each schema is instantiated with every combination of problem objects whose
    type names match the schema's argument types. A grounding is discarded when
    one of its preconditions refers to a static predicate and the grounded
    proposition is not in the problem's initial state.

    Args:
        domain: the domain providing ``operators``, ``regression_rules`` and ``predicates``.
        problem: the problem providing ``objects_type2names`` and ``initial_state``.
        mode: ``'action'`` to ground ``domain.operators``, ``'rule'`` to ground
            ``domain.regression_rules``.

    Returns:
        The list of grounded appliers that passed the static-precondition check.

    Raises:
        ValueError: if ``mode`` is neither ``'action'`` nor ``'rule'`` (only reachable
            when assertions are disabled, because of the assert below).
    """
    assert mode in ('action', 'rule')

    if mode == 'action':
        schemas = domain.operators
    elif mode == 'rule':
        schemas = domain.regression_rules
    else:
        raise ValueError('Unknown mode: {}.'.format(mode))

    grounded = []
    for schema in schemas.values():
        # One pool of candidate object names per schema argument, looked up by type name.
        object_pools = [
            problem.objects_type2names[argument.dtype.typename]
            for argument in schema.arguments
        ]
        for binding in itertools.product(*object_pools):
            applier = schema(*binding)
            # The i-th grounded precondition is matched with the i-th lifted
            # precondition of the schema to look up whether its predicate is static.
            violates_static = any(
                domain.predicates[schema.preconditions[index].name].is_static
                and proposition not in problem.initial_state
                for index, proposition in enumerate(applier.preconditions)
            )
            if violates_static:
                continue
            grounded.append(applier)
    return grounded
def astrips_regression_search_1(domain: AtomicStripsDomain, problem: AtomicStripsProblem, verbose: bool = False) -> List[AtomicStripsOperatorApplier]:
    """Search for a plan using regression search.

    The problem must have a conjunctive goal made of exactly one proposition, and
    no operator or regression rule in the domain may have a negated precondition.
    Starting from the goal, the search picks the single grounded regression rule
    whose goal matches and whose preconditions hold in the current state, then
    walks the rule body: achieve-expressions whose goal is not yet in the state
    are solved recursively, and operator appliers are executed on the state.

    Args:
        domain: the domain providing ``operators`` and ``regression_rules``.
        problem: the problem providing ``conjunctive_goal`` and ``initial_state``.
        verbose: if True, print the search trace through ``jacinle.log_function.print``.

    Returns:
        The list of operator appliers collected along the search, in execution order.

    Raises:
        NotImplementedError: if any operator or regression rule has a negated precondition.
        ValueError: if a rule body contains an item that is neither an achieve
            expression nor an operator applier.
    """
    assert problem.conjunctive_goal is not None, "Only conjunctive goals are supported."
    assert len(problem.conjunctive_goal) == 1, "Only single-goal problems are supported."

    def _reject_negated_preconditions(schemas):
        # Shared validation for operators and regression rules.
        for schema in schemas:
            for precondition in schema.preconditions:
                if precondition.negated:
                    raise NotImplementedError('astrips_regression_search does not support negated preconditions.')

    _reject_negated_preconditions(domain.operators.values())
    _reject_negated_preconditions(domain.regression_rules.values())

    grounded_rules = gen_all_grounded_actions_and_rules(domain, problem, 'rule')

    def _pick_unique_rule(state: Set[SProposition], goal: SProposition, maintains: Set[SProposition]):
        # `maintains` is accepted but not consulted when selecting the rule.
        matches = [
            candidate for candidate in grounded_rules
            if candidate.goal == goal and state.issuperset(candidate.preconditions)
        ]
        assert len(matches) == 1, "Only one applicable rule is allowed."
        return matches[0]

    @jacinle.log_function(verbose=False)
    def dfs(state: Set[SProposition], goal: SProposition, maintains: Set[SProposition]):
        rule = _pick_unique_rule(state, goal, maintains)
        if verbose:
            jacinle.log_function.print(f'Current state: {state}, goal: {goal}, maintains: {maintains} => rule: {rule}')

        plan = []
        for item in rule.body:
            if isinstance(item, AtomicStripsGroundedAchieveExpression):
                if item.goal in state:
                    # Subgoal already holds; nothing to do for this step.
                    if verbose:
                        jacinle.log_function.print(f'{str(item)} (skipped)')
                    continue
                if verbose:
                    jacinle.log_function.print(f'{str(item)}')
                # Recurse on the subgoal; the returned state replaces the current one.
                state, sub_plan = dfs(state, item.goal, maintains.union(item.maintains))
                plan.extend(sub_plan)
                continue

            if isinstance(item, AtomicStripsOperatorApplier):
                if verbose:
                    jacinle.log_function.print(f'do({str(item)})')
                plan.append(item)
                assert state.issuperset(item.preconditions), "The preconditions of an action should be satisfied."
                # Apply the action: remove delete effects, then add add effects.
                remaining = state - frozenset(item.del_effects)
                state = remaining.union(frozenset(item.add_effects))
                continue

            raise ValueError('Unknown item: {}.'.format(item))
        return state, plan

    # Only the plan is returned; the final state is discarded.
    _, plan = dfs(problem.initial_state, problem.conjunctive_goal[0], set())
    return plan