Coverage for sparkle/platform/output/configuration_output.py: 99%
68 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1#!/usr/bin/env python3
2"""Sparkle class to organise configuration output."""
4from __future__ import annotations
6from sparkle.platform import \
7 generate_report_for_configuration as sgrfch
8from sparkle.solver import Solver
9from sparkle.instance import InstanceSet
10from sparkle.configurator.configurator import Configurator, ConfigurationScenario
11from sparkle.solver.validator import Validator
12from sparkle.platform.output.structures import ValidationResults, ConfigurationResults
13from sparkle.types import SolverStatus
15import json
16from pathlib import Path
class ConfigurationOutput:
    """Class that collects configuration data and outputs it in a JSON format."""

    def __init__(self: ConfigurationOutput, path: Path,
                 configurator: Configurator, config_scenario: ConfigurationScenario,
                 instance_set_test: InstanceSet, output: Path) -> None:
        """Initialize Configurator Output class.

        Args:
            path: Path to configuration output directory
            configurator: The configurator that was used
            config_scenario: The scenario to output
            instance_set_test: Instance set used for testing, may be None
            output: Path to the output directory, or to an existing output file

        Raises:
            FileNotFoundError: If "<path.name>_scenario.txt" is not a file in path
        """
        self.solver = config_scenario.solver
        self.configurator = configurator
        self.instance_set_train = config_scenario.instance_set
        self.instance_set_test = instance_set_test
        self.directory = path
        self.config_scenario = config_scenario
        # An existing file is used as-is, anything else is treated as a directory
        self.output = output / "configuration.json" if not output.is_file() else output

        scenario_file = path / f"{path.name}_scenario.txt"
        if not scenario_file.is_file():
            # FileNotFoundError is an Exception subclass, so callers that catch
            # Exception keep working while the message now names the file
            raise FileNotFoundError(f"Can't find scenario file: {scenario_file}")

        # Retrieve all configurations
        config_path = path / "validation" / "configurations.csv"
        self.configurations = self.get_configurations(config_path)

        # Retrieve best found configuration
        _, self.best_config = self.configurator.get_optimal_configuration(
            self.config_scenario)

        # Retrieve training validation results for all configurations
        self.validation_results = [
            self.get_validation_data(self.instance_set_train, config)
            for config in self.configurations]

        # Always defined, so the attribute can be read when no test set was given
        self.validation_results_test = []
        if self.instance_set_test is not None:
            self.validation_results_test = [
                self.get_validation_data(self.instance_set_test, config)
                for config in self.configurations]

    def get_configurations(self: ConfigurationOutput, config_path: Path) -> list[dict]:
        """Read all configurations and transform them to dictionaries.

        Args:
            config_path: Path to the file that is read line by line, each line
                being converted with Solver.config_str_to_dict

        Returns:
            The distinct configurations in order of first appearance. The list is
            empty if config_path is not an existing file.
        """
        configs = []
        # is_file() is already False for a path that does not exist
        if not config_path.is_file():
            return configs
        with config_path.open("r") as file:
            for line in file:
                config = Solver.config_str_to_dict(line.strip())
                # Dictionaries are not hashable, hence the list membership test
                if config not in configs:
                    configs.append(config)
        return configs

    def get_validation_data(self: ConfigurationOutput, instance_set: InstanceSet,
                            config: dict) -> ConfigurationResults:
        """Collect the validation results on an instance set.

        Args:
            instance_set: The instance set to retrieve the validation results of
            config: The configuration stored in the returned ValidationResults

        Returns:
            ConfigurationResults holding the average performance and the
            per-instance validation results.

        Raises:
            ValueError: If the objective name, "Instance", "Status", "CPU Time" or
                "Wallclock Time" is missing from the validation results header
        """
        objective = self.config_scenario.sparkle_objective

        # Retrieve found configuration
        _, best_config = self.configurator.get_optimal_configuration(
            self.config_scenario)

        # Retrieve validation results
        # NOTE(review): the validator is queried with best_config, the 'config'
        # argument is only used to label the results below. Presumably each
        # configuration should be queried on its own -- confirm which config
        # type Validator.get_validation_results accepts before changing this.
        validator = Validator(self.directory)
        val_results = validator.get_validation_results(
            self.solver, instance_set, config=best_config,
            source_dir=self.directory, subdir="validation")

        # The first row is the header, used to locate the columns of interest
        header = val_results[0]
        value_column = header.index(objective.name)
        instance_column = header.index("Instance")
        status_column = header.index("Status")
        cpu_time_column = header.index("CPU Time")
        wall_time_column = header.index("Wallclock Time")
        results = [
            [res[instance_column], SolverStatus(res[status_column]),
             res[value_column], res[cpu_time_column], res[wall_time_column]]
            for res in val_results[1:]]
        final_results = ValidationResults(self.solver, config,
                                          instance_set, results)
        perf_par = sgrfch.get_average_performance(val_results,
                                                  objective)
        return ConfigurationResults(perf_par,
                                    final_results)

    def serialize_configuration_results(self: ConfigurationOutput,
                                        cr: ConfigurationResults) -> dict:
        """Transform ConfigurationResults to dictionary format.

        Args:
            cr: The ConfigurationResults to transform

        Returns:
            Dictionary with the performance and, under "results", the solver name,
            configuration, instance set name, result header and result values.
        """
        validation = cr.results
        return {
            "performance": cr.performance,
            "results": {
                "solver": validation.solver.name,
                "configuration": validation.configuration,
                "instance_set": validation.instance_set.name,
                "result_header": validation.result_header,
                "result_values": validation.result_vals,
            },
        }

    def write_output(self: ConfigurationOutput) -> None:
        """Write data into a JSON file.

        The parent directory of the output file is created if it is missing.
        """
        output_data = {
            "solver": self.solver.name if self.solver else None,
            "configurator": (
                str(self.configurator) if self.configurator else None
            ),
            "best_configuration": Solver.config_str_to_dict(self.best_config),
            "configurations": self.configurations,
            # NOTE(review): the guard reads configurator.scenario, whereas the
            # serialized object is self.config_scenario -- confirm this is meant.
            "scenario": self.config_scenario.serialize()
            if self.configurator.scenario else None,
            "training_results": [
                self.serialize_configuration_results(validation_result)
                for validation_result in self.validation_results
            ],
            "test_set": (
                [
                    self.serialize_configuration_results(validation_result)
                    for validation_result in self.validation_results_test
                ]
                if self.instance_set_test else None
            ),
        }

        self.output.parent.mkdir(parents=True, exist_ok=True)
        with self.output.open("w") as f:
            json.dump(output_data, f, indent=4)