#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : regression_dependency.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 08/5/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""Dependencies for the regression planning."""
import os
import tempfile
from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple, List, Dict

from jacinle.utils.printing import indent_text

from concepts.dm.crow.planners.regression_planning import SupportedCrowExpressionType
from concepts.dm.crow.behavior_utils import format_behavior_statement
@dataclass(unsafe_hash=True)
class RegressionTraceStatement(object):
    """A single statement recorded in a regression trace.

    Instances are hashable because the dependency graph uses them as dictionary keys. The hash is generated
    from every field except ``scope``; equality still compares all fields.
    """

    # The traced statement itself.
    stmt: 'SupportedCrowExpressionType'
    # Scope id attached to the statement; ``recover_graph_from_trace`` uses it to look up the parent node.
    scope_id: Optional[int] = None
    # When set, ``node_string`` formats the statement with this scope id instead of ``scope_id``.
    new_scope_id: Optional[int] = None
    # Optional free-form note appended to the rendered node text.
    additional_info: Optional[str] = None
    # A dict is unhashable, so this field is excluded from the generated ``__hash__`` (it would raise
    # ``TypeError`` as soon as a scope is attached). It still takes part in ``__eq__``.
    scope: Optional[dict] = field(default=None, hash=False)
    # Optional statement that ``stmt`` was derived from; rendered on an extra line by ``node_string``.
    derived_from: Optional['SupportedCrowExpressionType'] = None

    def node_string(self, scopes: Dict[int, dict]) -> str:
        """Format this statement as the (possibly multi-line) text of a graph node.

        Args:
            scopes: the scope table forwarded to :func:`format_behavior_statement`.

        Returns:
            the formatted statement, followed by an optional "derived from" line and an optional "note" line.
        """
        # Prefer the newly created scope when there is one.
        scope_id = self.new_scope_id if self.new_scope_id is not None else self.scope_id
        basic_fmt = format_behavior_statement(self.stmt, scopes, scope_id)
        if self.derived_from is not None:
            derived_fmt = format_behavior_statement(self.derived_from, scopes, scope_id)
            basic_fmt += '\n derived from: ' + indent_text(derived_fmt, 1, indent_first=False)
        if self.additional_info is not None:
            basic_fmt += '\n note: ' + self.additional_info
        return basic_fmt
class RegressionDependencyGraph(object):
    """A dependency graph over the statements of a regression trace.

    Nodes are kept in insertion order in :attr:`nodes`. :attr:`node2index` maps a node back to its index,
    and :attr:`edges` maps a parent node to the list of indices of its children.
    """

    def __init__(self, scopes: Dict[int, dict]):
        """Create an empty graph.

        Args:
            scopes: the scope table passed to ``node_string`` whenever a node is formatted.
        """
        self.scopes = scopes
        self.nodes = list()
        self.node2index = dict()
        self.edges = dict()

    nodes: List['RegressionTraceStatement']
    node2index: Dict['RegressionTraceStatement', int]
    edges: Dict['RegressionTraceStatement', List[int]]

    def add_node(self, node: 'RegressionTraceStatement') -> 'RegressionDependencyGraph':
        """Append a node to the graph and record its index.

        Args:
            node: the node to add. It must be hashable because it is used as a dictionary key.

        Returns:
            the graph itself, so that calls can be chained.
        """
        self.node2index[node] = len(self.nodes)
        self.nodes.append(node)
        return self

    def connect(self, x: 'RegressionTraceStatement', y: 'RegressionTraceStatement') -> 'RegressionDependencyGraph':
        """Connect two nodes in the dependency graph. x is the "parent" of y.

        Args:
            x: the parent node.
            y: the child node. It must have been added with :meth:`add_node`, otherwise a ``KeyError`` is raised.

        Returns:
            the graph itself, so that calls can be chained.
        """
        self.edges.setdefault(x, list()).append(self.node2index[y])
        return self

    def print(self, i: int = 0, indent_level: int = 0) -> None:
        """Recursively print the subtree rooted at node ``i``, one indentation level per depth.

        Args:
            i: the index of the node to start from.
            indent_level: the indentation level used for node ``i``.
        """
        print(indent_text(f'{i}::' + self.nodes[i].node_string(self.scopes), indent_level))
        for child in self.edges.get(self.nodes[i], []):
            self.print(child, indent_level + 1)

    def sort_nodes_into_levels(self) -> List[List[int]]:
        """Group node indices by their height in the graph.

        A node without children is at level 0; any other node is one level above its highest child.
        Every node is assigned a level, including nodes that are not reachable from node 0.

        Returns:
            a list whose ``k``-th entry holds the indices (in increasing order) of the nodes at level ``k``.
            The list is empty when the graph has no nodes.
        """
        levels: Dict[int, int] = dict()

        def dfs(i: int) -> int:
            # Memoise so that shared children are only visited once.
            if i in levels:
                return levels[i]
            max_child_level = -1
            for child in self.edges.get(self.nodes[i], []):
                max_child_level = max(dfs(child), max_child_level)
            levels[i] = max_child_level + 1
            return levels[i]

        # Start from every node rather than only node 0, so disconnected nodes get a level as well.
        for i in range(len(self.nodes)):
            dfs(i)
        if not levels:
            return []

        output_levels = [[] for _ in range(max(levels.values()) + 1)]
        for i in range(len(self.nodes)):
            output_levels[levels[i]].append(i)
        return output_levels

    def render_graphviz(self, filename: Optional[str] = None) -> None:
        """Render the graph with Graphviz.

        Args:
            filename: the output path. It must end with ``.png``, ``.pdf``, or ``.dot``. When it is None, the graph
                is rendered to a temporary PDF file which is then opened with the ``open`` command.

        Raises:
            ImportError: if the ``graphviz`` Python package is not installed.
            ValueError: if ``filename`` has an unsupported extension.
        """
        import time

        try:
            import graphviz
        except ImportError as e:
            raise ImportError('Please install graphviz first by running "pip install graphviz".') from e

        dot = graphviz.Digraph(comment='Regression Dependency Graph')
        for i, node in enumerate(self.nodes):
            # '\l' is the Graphviz escape for a left-justified line break.
            dot.node(str(i), node.node_string(self.scopes).replace('\n', '\\l') + '\\l', shape='rectangle')

        # Attach the nodes of each level to an invisible anchor node and chain the anchors with invisible edges.
        levels = self.sort_nodes_into_levels()
        for i, level in enumerate(levels):
            dot.node(f'level_{i}', '', ordering='out', style='invis')
            for j in level:
                dot.edge(f'level_{i}', str(j), style='invis')
        for i in range(len(levels) - 1, 0, -1):
            dot.edge(f'level_{i}', f'level_{i - 1}', style='invis')

        for x, ys in self.edges.items():
            for y in ys:
                dot.edge(str(self.node2index[x]), str(y))

        if filename is not None:
            if filename.endswith(('.png', '.pdf')):
                # Graphviz appends the format extension itself, so pass the path without it.
                dot.render(filename[:-4], format=filename[-3:], cleanup=True)
            elif filename.endswith('.dot'):
                # Only write the DOT source; rendering here would also create an extra "<filename>.pdf".
                dot.save(filename)
            else:
                raise ValueError(f'Unsupported file format: {filename}. Only PNG, PDF, and DOT are supported.')
            print(f'Graphviz file saved to "{filename}".')
        else:
            with tempfile.NamedTemporaryFile(suffix='.pdf') as f:
                dot.render(f.name[:-4], format='pdf', cleanup=True)
                print(f'Graphviz file saved to "{f.name}". Now opening it in the default PDF viewer...')
                # NOTE(review): `open` is the macOS launcher; other platforms probably need a different command.
                os.system(f'open "{f.name}"')
                # We need to sleep for a while to prevent the file from being deleted too early.
                time.sleep(3)
def recover_graph_from_trace(trace: Sequence[RegressionTraceStatement], scopes: Dict[int, dict]) -> RegressionDependencyGraph:
    """Rebuild the dependency graph from a flat regression trace.

    The first statement becomes node 0. Every later statement is attached as a child of the statement that
    introduced its ``scope_id`` (that is, the most recent statement whose ``new_scope_id`` equals it). A statement
    whose ``scope_id`` has not been introduced is still added as a node, but it is left unconnected.

    Args:
        trace: the recorded statements, in order.
        scopes: the scope table stored on the returned graph.

    Returns:
        the recovered dependency graph. It has no nodes when ``trace`` is empty.
    """
    graph = RegressionDependencyGraph(scopes)
    if not trace:
        return graph

    root = trace[0]
    graph.add_node(root)
    # The root is registered unconditionally, even when its ``new_scope_id`` is None.
    scope_to_node = {root.new_scope_id: root}

    for stmt in trace[1:]:
        graph.add_node(stmt)
        parent = scope_to_node.get(stmt.scope_id)
        if stmt.scope_id in scope_to_node:
            graph.connect(parent, stmt)
        if stmt.new_scope_id is not None:
            scope_to_node[stmt.new_scope_id] = stmt
    return graph