Coverage for sparkle/platform/output/configuration_output.py: 99%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1#!/usr/bin/env python3 

2"""Sparkle class to organise configuration output.""" 

3 

4from __future__ import annotations 

5 

6from sparkle.platform import \ 

7 generate_report_for_configuration as sgrfch 

8from sparkle.solver import Solver 

9from sparkle.instance import InstanceSet 

10from sparkle.configurator.configurator import Configurator, ConfigurationScenario 

11from sparkle.solver.validator import Validator 

12from sparkle.platform.output.structures import ValidationResults, ConfigurationResults 

13from sparkle.types import SolverStatus 

14 

15import json 

16from pathlib import Path 

17 

18 

class ConfigurationOutput:
    """Class that collects configuration data and outputs it in a JSON format."""

    def __init__(self: ConfigurationOutput, path: Path,
                 configurator: Configurator, config_scenario: ConfigurationScenario,
                 instance_set_test: InstanceSet, output: Path) -> None:
        """Initialize Configurator Output class.

        Args:
            path: Path to configuration output directory
            configurator: The configurator that was used
            config_scenario: The scenario to output
            instance_set_test: Instance set used for testing, or None
            output: Path to the output directory, or to an existing output file

        Raises:
            FileNotFoundError: If "<path.name>_scenario.txt" is not a file in path
        """
        self.solver = config_scenario.solver
        self.configurator = configurator
        self.instance_set_train = config_scenario.instance_set
        self.instance_set_test = instance_set_test
        self.directory = path
        self.config_scenario = config_scenario
        # An existing file is used as-is, anything else is treated as a directory
        self.output = output if output.is_file() else output / "configuration.json"

        # The scenario file is named after the configuration output directory
        scenario_file = path / f"{path.name}_scenario.txt"
        if not scenario_file.is_file():
            raise FileNotFoundError(f"Can't find scenario file: {scenario_file}")

        # Retrieve all configurations
        config_path = path / "validation" / "configurations.csv"
        self.configurations = self.get_configurations(config_path)

        # Retrieve best found configuration
        _, self.best_config = self.configurator.get_optimal_configuration(
            self.config_scenario)

        # Retrieves validation results for all configurations
        self.validation_results = [
            self.get_validation_data(self.instance_set_train, config)
            for config in self.configurations
        ]

        # Retrieve test validation results if they exist. The attribute is always
        # defined, and stays None when no test set was given.
        self.validation_results_test = None
        if self.instance_set_test is not None:
            self.validation_results_test = [
                self.get_validation_data(self.instance_set_test, config)
                for config in self.configurations
            ]

    def get_configurations(self: ConfigurationOutput, config_path: Path) -> list[dict]:
        """Read all configurations and transform them to dictionaries.

        Args:
            config_path: Path to the file with one configuration string per line

        Returns:
            The unique configurations in order of first appearance. The list is
            empty if config_path is not an existing file.
        """
        configs = []
        # is_file() is already False for a path that does not exist
        if not config_path.is_file():
            return configs
        with config_path.open("r") as file:
            for line in file:
                config = Solver.config_str_to_dict(line.strip())
                # Duplicates are filtered by equality, keeping the first occurrence
                if config not in configs:
                    configs.append(config)
        return configs

    def get_validation_data(self: ConfigurationOutput, instance_set: InstanceSet,
                            config: dict) -> ConfigurationResults:
        """Returns the ConfigurationResults for an instance set.

        Args:
            instance_set: Instance set to retrieve the validation results for
            config: Configuration that is stored in the returned results

        Returns:
            ConfigurationResults holding the average performance and the
            per-instance validation results.
        """
        objective = self.config_scenario.sparkle_objective

        # Retrieve found configuration
        _, best_config = self.configurator.get_optimal_configuration(
            self.config_scenario)

        # Retrieve validation results
        # NOTE(review): the results are requested for the optimal configuration,
        # while `config` is only attached to the returned ValidationResults.
        # Confirm whether the lookup should use `config` instead.
        validator = Validator(self.directory)
        val_results = validator.get_validation_results(
            self.solver, instance_set, config=best_config,
            source_dir=self.directory, subdir="validation")

        # First row is the header, used to locate the columns of interest
        header = val_results[0]
        value_column = header.index(objective.name)
        instance_column = header.index("Instance")
        status_column = header.index("Status")
        cpu_time_column = header.index("CPU Time")
        wall_time_column = header.index("Wallclock Time")
        results = [
            [res[instance_column], SolverStatus(res[status_column]),
             res[value_column], res[cpu_time_column], res[wall_time_column]]
            for res in val_results[1:]
        ]
        final_results = ValidationResults(self.solver, config,
                                          instance_set, results)
        perf_par = sgrfch.get_average_performance(val_results,
                                                  objective)
        return ConfigurationResults(perf_par,
                                    final_results)

    def serialize_configuration_results(self: ConfigurationOutput,
                                        cr: ConfigurationResults) -> dict:
        """Transform ConfigurationResults to dictionary format.

        Args:
            cr: The configuration results to transform

        Returns:
            Dictionary with the performance and the nested validation results.
        """
        return {
            "performance": cr.performance,
            "results": {
                "solver": cr.results.solver.name,
                "configuration": cr.results.configuration,
                "instance_set": cr.results.instance_set.name,
                "result_header": cr.results.result_header,
                "result_values": cr.results.result_vals,
            },
        }

    def write_output(self: ConfigurationOutput) -> None:
        """Write data into a JSON file.

        Missing parent directories of the output file are created.
        """
        output_data = {
            "solver": self.solver.name if self.solver else None,
            "configurator": (
                str(self.configurator) if self.configurator else None
            ),
            "best_configuration": Solver.config_str_to_dict(self.best_config),
            "configurations": self.configurations,
            # NOTE(review): the guard checks the configurator's scenario, but the
            # serialized object is self.config_scenario. Confirm this is intended.
            "scenario": self.config_scenario.serialize()
            if self.configurator.scenario else None,
            "training_results": [
                self.serialize_configuration_results(validation_result)
                for validation_result in self.validation_results
            ],
            # Same None check as used in __init__ to collect the test results
            "test_set": (
                [
                    self.serialize_configuration_results(validation_result)
                    for validation_result in self.validation_results_test
                ]
                if self.instance_set_test is not None else None
            ),
        }

        self.output.parent.mkdir(parents=True, exist_ok=True)
        with self.output.open("w") as f:
            json.dump(output_data, f, indent=4)