Source code for omnisafe.common.statistics_tools

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the statistics tools."""

from __future__ import annotations

import itertools
import json
import os
from copy import deepcopy
from typing import Any, Generator

from omnisafe.utils.plotter import Plotter
from omnisafe.utils.tools import assert_with_exit, hash_string, recursive_dict2json, update_dict


class StatisticsTools:
    """Analyze experiments results launched by experiment grid.

    Users can choose any parameters to compare the results.
    Aiming to help users to find the best hyperparameter faster.

    Attributes:
        grid_config_dir (str): The path of the ``grid_config.json`` file.
        decompressed_grid_config (dict[str, Any]): The decompressed grid_config.json.
        path_map_img_name (dict[str, Any]): The map from experiment path to the config used to
            name the image of the graph containing that path.
        grid_config (dict[str, Any]): The content of grid_config.json.
        exp_dir (str): The experiment directory (the directory holding grid_config.json).
        plotter (Plotter): The plotter.
    """

    grid_config_dir: str
    decompressed_grid_config: dict[str, Any]
    path_map_img_name: dict[str, Any]
    grid_config: dict[str, Any]
    exp_dir: str

    def __init__(self) -> None:
        """Initialize an instance of :class:`StatisticsTools`."""
        self.plotter: Plotter = Plotter()

    def load_source(self, path: str) -> None:
        """Load experiment results.

        Recursively searches ``path`` for exactly one directory containing ``grid_config.json``
        and loads it.

        Args:
            path (str): The experiment directory.

        Raises:
            FileNotFoundError: If the config file disappears between discovery and loading.
        """
        # recursively find directory which is generated by experiment grid
        grid_config_dirs = [root for root, _, files in os.walk(path) if 'grid_config.json' in files]
        # check the search result of *this* call, not attributes left over from a previous call
        assert_with_exit(
            len(grid_config_dirs) > 0,
            'cannot find directory which is initialized by experiment grid via grid_config.json',
        )
        assert_with_exit(
            len(grid_config_dirs) == 1,
            'there should be only one experiment grid directory',
        )
        self.exp_dir = grid_config_dirs[0]
        self.grid_config_dir = os.path.join(self.exp_dir, 'grid_config.json')

        # load the config file of experiment grid
        try:
            with open(self.grid_config_dir, encoding='utf-8') as file:
                self.grid_config = json.load(file)
        except FileNotFoundError as error:
            raise FileNotFoundError(
                'The config file is not found in the save directory.',
            ) from error

    def draw_graph(
        self,
        parameter: str,
        values: list[Any] | None = None,
        compare_num: int | None = None,
        cost_limit: float | None = None,
        smooth: int = 1,
        show_image: bool = False,
    ) -> None:
        """Draw graph.

        Args:
            parameter (str): The parameter to compare.
            values (list[Any] or None, optional): The values of the parameter to compare.
                An empty list is treated as ``None``. Defaults to None.
            compare_num (int or None, optional): The number of values to compare. ``0`` is
                treated as ``None``. Defaults to None.
            cost_limit (float or None, optional): The cost limit of the experiment. Defaults to None.
            smooth (int, optional): The smooth window size. Defaults to 1.
            show_image (bool): Whether to show graph image in GUI windows.

        .. note::
            `values` and `compare_num` cannot be set at the same time. If neither is set, all
            values of ``parameter`` are compared in a single graph per group.
        """
        # check whether operation is valid
        assert_with_exit(
            not (values and compare_num),
            'values and compare_num cannot be set at the same time',
        )
        assert_with_exit(hasattr(self, 'grid_config'), 'please load source first')
        assert_with_exit(
            parameter in self.grid_config,
            f'parameter scope `{parameter}` is not in {self.grid_config}',
        )
        # Falsy values (``[]`` / ``0``) mean "not specified"; normalize them so that exactly one
        # of ``values`` / ``compare_num`` is non-None when calling ``make_config_groups``.
        values = values or None
        compare_num = compare_num or None

        # decompress the grid config
        decompressed_cfgs: dict = {}
        for k, v in self.grid_config.items():
            update_dict(decompressed_cfgs, self.decompress_key(k, v))
        self.decompressed_grid_config = decompressed_cfgs

        parameter_values = self.get_compressed_key(self.decompressed_grid_config, parameter)

        # make config groups via the combination of parameter values
        if values is None and compare_num is None:
            compare_num = len(parameter_values)

        graph_paths = self.make_config_groups(parameter, parameter_values, values, compare_num)

        for graph_dict in graph_paths:
            if not graph_dict:
                # nothing to plot for this group
                continue
            # keys of ``graph_dict`` are ``(parameter, value)`` tuples, values are log dirs
            legend = [f'{value}' for _, value in graph_dict]
            log_dirs = list(graph_dict.values())

            img_name_cfgs = self.path_map_img_name[log_dirs[-1]]
            decompressed_img_name_cfgs: dict = {}
            for k, v in img_name_cfgs.items():
                update_dict(decompressed_img_name_cfgs, self.decompress_key(k, v[0]))

            # the first element of every key in ``graph_dict`` is ``parameter`` itself
            save_name = (
                parameter[:10]
                + '---'
                + decompressed_img_name_cfgs['env_id'][:30]
                + '---'
                + hash_string(recursive_dict2json(decompressed_img_name_cfgs))
            )
            try:
                self.plotter.make_plots(
                    log_dirs,
                    legend,
                    'Steps',
                    'Rewards',
                    False,
                    cost_limit,
                    smooth,
                    None,
                    None,
                    'mean',
                    save_name=save_name,
                    show_image=show_image,
                )
            except Exception as error:  # noqa # pragma: no cover # pylint: disable=broad-except
                # best effort: report the failing group and keep plotting the others
                print(
                    f'Cannot generate graph for {save_name[:5] + str(decompressed_img_name_cfgs)}',
                )
                print(error)

    def make_config_groups(
        self,
        parameter: str,
        parameter_values: list[str],
        values: list[Any] | None = None,
        compare_num: int | None = None,
    ) -> list[dict[tuple[str, Any], str]]:
        """Make config groups.

        Each group contains a list of config paths to compare.

        .. warning::
            `values` and `compare_num` cannot be set at the same time.

        Args:
            parameter (str): The parameter to compare.
            parameter_values (list[str]): The values of the parameter to compare.
            values (list[Any] or None, optional): The values of the parameter to compare. Defaults to None.
            compare_num (int or None, optional): The number of values to compare. Defaults to None.

        Returns:
            A list of graph paths. Each element maps ``(parameter, value)`` to the experiment path.
        """
        self.path_map_img_name = {}
        parameter_values_combination: list[tuple] = []
        graph_groups: list[list] = []
        assert (values is not None) ^ (
            compare_num is not None
        ), 'The values and compare_num cannot be set at the same time'
        if values:
            assert_with_exit(
                all(v in parameter_values for v in values),
                f'values `{values}` of parameter `{parameter}` is not subset of `{parameter_values}`',
            )
            # if values is specified, will only compare values in it
            parameter_values_combination = [tuple(values)]
        if compare_num:
            assert_with_exit(
                compare_num <= len(parameter_values),
                (
                    f'compare_num `{compare_num}` is larger than number of values '
                    f'`{len(parameter_values)}` of parameter `{parameter}`'
                ),
            )
            # if compare_num is specified, will combine any potential combination to compare
            parameter_values_combination = list(self.combine(parameter_values, compare_num))

        group_config = deepcopy(self.grid_config)
        # value of parameter is determined above
        group_config.pop(parameter)
        # seed is not a parameter
        group_config.pop('seed', None)
        if 'train_cfgs' in group_config:
            group_config['train_cfgs'].pop('device', None)

        # combine all possible combinations of other parameters
        # fix them in a single graph and only vary values of parameter which is specified by us
        for pinned_config in self.dict_permutations(group_config):
            for compare_value in parameter_values_combination:
                # ``parameter`` is appended last, so it varies innermost in ``variants``
                current_config = {**pinned_config, parameter: list(compare_value)}
                graph_groups.append(
                    [
                        deepcopy(current_config),
                        self.variants(list(current_config.keys()), list(current_config.values())),
                    ],
                )

        graph_paths = []
        for img_name_cfgs, graph in graph_groups:
            paths = {}
            for path_dict in graph:
                exp_name = path_dict['env_id'][:30] + '---' + hash_string(recursive_dict2json(path_dict))
                path = os.path.join(self.exp_dir, exp_name)
                self.path_map_img_name[path] = img_name_cfgs
                para_val = (parameter, self.get_compressed_key(path_dict, parameter))
                paths[para_val] = path
            graph_paths.append(paths)

        return graph_paths

    def decompress_key(self, compressed_key: str, value: Any) -> dict[str, Any]:
        """Convert a colon-separated key and its value to a nested dict.

        .. note::
            For example, if the custom configurations are ``train_cfgs:use_wandb`` and ``True``,
            then the output dict will be ``{'train_cfgs': {'use_wandb': True}}``. Dashes in the
            key are replaced by underscores.

        Args:
            compressed_key (str): The compressed key.
            value (Any): The value of the compressed key.

        Returns:
            The decompressed dict.
        """
        keys_split = compressed_key.replace('-', '_').split(':')
        return_dict = {keys_split[-1]: value}
        for key in reversed(keys_split[:-1]):
            return_dict = {key: return_dict}
        return return_dict

    def _variants(self, keys: list[str], vals: list[Any]) -> list[dict[str, Any]]:
        """Recursively builds list of valid variants.

        The first key varies slowest (outermost loop), the last key varies fastest.

        Args:
            keys (list[str]): The keys of the config.
            vals (list[Any]): The candidate values of each key (one list per key).

        Returns:
            List of valid variants.
        """
        if len(keys) == 1:
            pre_variants: list[dict[str, Any]] = [{}]
        else:
            pre_variants = self._variants(keys[1:], vals[1:])

        key_list = keys[0].split(':')
        variants = []
        for val in vals[0]:
            for pre_v in pre_variants:
                current_variants = deepcopy(pre_v)
                v_temp: dict[str, Any] = {key_list[-1]: val}
                for key in reversed(key_list[:-1]):
                    v_temp = {key: v_temp}
                self.update_dict(current_variants, v_temp)
                variants.append(current_variants)
        return variants

    def update_dict(self, total_dict: dict[str, Any], item_dict: dict[str, Any]) -> None:
        """Updater of multi-level dictionary.

        Nested dicts in ``item_dict`` are merged into existing entries of ``total_dict``;
        any other value overwrites (or creates) the entry. ``total_dict`` is modified in place.

        Args:
            total_dict (dict[str, Any]): The total dictionary.
            item_dict (dict[str, Any]): The item dictionary.
        """
        for key, item_value in item_dict.items():
            total_value = total_dict.get(key)
            if total_value is not None and isinstance(item_value, dict):
                self.update_dict(total_value, item_value)
            else:
                total_dict[key] = item_value

    def variants(self, keys: list[str], vals: list[Any]) -> list[dict[str, Any]]:
        """Makes a list of dict, where each dict is a valid config in the grid.

        There is special handling for variant parameters whose names take the form
        ``'full:param:name'``.

        The colons are taken to indicate that these parameters should have a nested dict
        structure. For example, if there are two params,

        ====================  ===
        Key                   Val
        ====================  ===
        ``'base:param:a'``    1
        ``'base:param:b'``    2
        ====================  ===

        the variant dict will have the structure

        .. parsed-literal::

            variant = {
                base: {
                    param : {
                        a : 1,
                        b : 2
                        }
                    }
                }

        Args:
            keys (list[str]): The keys of the config.
            vals (list[Any]): The values of the config.

        Returns:
            List of valid and not duplicate variants.
        """
        # dict keys are unique by construction, so each variant only needs a shallow copy
        return [dict(var) for var in self._variants(keys, vals)]

    def combine(self, sequence: list[str], num_choosen: int) -> Generator:
        """Combine elements in sequence to n elements.

        Yields the ``num_choosen``-element combinations of ``sequence`` in lexicographic index
        order. Nothing is yielded when ``num_choosen < 1`` or ``num_choosen > len(sequence)``.

        Args:
            sequence (list[str]): The sequence to be combined.
            num_choosen (int): The number of elements to be combined.

        Returns:
            The generator of the combined elements.
        """
        if num_choosen < 1:
            # itertools.combinations would yield ``()`` for 0 and raise for negatives
            return
        yield from itertools.combinations(sequence, num_choosen)

    def dict_permutations(self, input_dict: dict[str, Any]) -> list[dict[str, Any]]:
        """Generate all possible combinations of the values in a dictionary.

        Takes a dictionary with string keys and list values, and returns a list of dictionaries,
        one per element of the Cartesian product of the lists. Each value of the result is
        wrapped in a single-element list.

        Args:
            input_dict (dict[str, Any]): The input dictionary.

        Returns:
            The list of all possible combinations of the values in a dictionary.
        """
        keys = list(input_dict.keys())
        return [
            {key: [value] for key, value in zip(keys, combination)}
            for combination in itertools.product(*input_dict.values())
        ]

    def get_compressed_key(self, dictionary: dict[str, Any], key: str) -> Any:
        """Get the compressed value of the key.

        Args:
            dictionary (dict[str, Any]): the uncompressed dictionary.
            key (str): the colon-separated key, e.g. ``'algo_cfgs:steps_per_epoch'``.

        Returns:
            The compressed value of the key.
        """
        inner_config = dictionary
        for k in key.split(':'):
            inner_config = inner_config[k]
        return inner_config