Coverage for sparkle/platform/output/selection_output.py: 98%
45 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1#!/usr/bin/env python3
2"""Sparkle class to organise configuration output."""
4from __future__ import annotations
6from sparkle.structures import PerformanceDataFrame, FeatureDataFrame
7from sparkle.platform import generate_report_for_selection as sgfs
8from sparkle.types.objective import SparkleObjective
9from sparkle.instance import InstanceSet
10from sparkle.platform.output.structures import SelectionPerformance, SelectionSolverData
12import json
13from pathlib import Path
class SelectionOutput:
    """Class that collects selection data and outputs it in JSON format."""

    def __init__(self: SelectionOutput, selection_scenario: Path,
                 train_data: PerformanceDataFrame,
                 feature_data: FeatureDataFrame,
                 training_instances: list[InstanceSet] | InstanceSet,
                 test_instances: list[InstanceSet] | InstanceSet | None,
                 objective: SparkleObjective,
                 cutoff_time: int,
                 output: Path) -> None:
        """Initialize SelectionOutput class.

        Args:
            selection_scenario: Path to selection output directory; the
                selector's performance is read from its "performance.csv"
            train_data: The performance input data for the selector
            feature_data: Feature data created by extractor
            training_instances: The training instance set(s). A single
                instance set is wrapped in a list.
            test_instances: The test instance set(s), or None when there are
                none. A single instance set is wrapped in a list.
            objective: The objective of the selector
            cutoff_time: The cutoff time
            output: Path of the JSON file to write, or of a directory in
                which "selection.json" will be written
        """
        self.output = self._resolve_output_path(output)
        # Normalise both instance arguments identically so the serializers
        # can always iterate over a list of instance sets.
        self.training_instances = self._as_instance_list(training_instances)
        self.test_instances = self._as_instance_list(test_instances)
        self.cutoff_time = cutoff_time

        self.objective = objective
        self.solver_data = self.get_solver_data(train_data, self.objective)
        # Collect marginal contribution data
        self.marginal_contribution_perfect = train_data.marginal_contribution(
            objective, sort=True)
        self.marginal_contribution_actual = \
            sgfs.compute_selector_marginal_contribution(train_data,
                                                        feature_data,
                                                        selection_scenario,
                                                        objective)

        # Collect performance data
        portfolio_selector_performance_path = \
            selection_scenario / "performance.csv"
        # The objective's instance aggregator is applied to the best
        # per-instance performance found in the training data.
        vbs_performance = objective.instance_aggregator(
            train_data.best_instance_performance(objective=objective.name))
        self.performance_data = SelectionPerformance(
            portfolio_selector_performance_path, vbs_performance,
            self.objective)

    @staticmethod
    def _resolve_output_path(output: Path | str) -> Path:
        """Return the JSON file to write for an output file or directory path.

        An existing directory gets "selection.json" appended. An existing file,
        or a not-yet-existing path with a ".json" suffix, is used as-is. Any
        other path is treated as a directory.
        """
        output = Path(output)
        if output.is_dir():
            return output / "selection.json"
        # Without the suffix check a new target such as "report.json" would be
        # mistaken for a directory and become "report.json/selection.json".
        if output.is_file() or output.suffix.lower() == ".json":
            return output
        return output / "selection.json"

    @staticmethod
    def _as_instance_list(
            instances: list[InstanceSet] | InstanceSet | None,
    ) -> list[InstanceSet] | None:
        """Wrap a single instance set in a list; pass lists and None through."""
        if instances is None or isinstance(instances, list):
            return instances
        return [instances]

    def get_solver_data(self: SelectionOutput,
                        train_data: PerformanceDataFrame,
                        objective: SparkleObjective) -> SelectionSolverData:
        """Build the SelectionSolverData object from the training data.

        Args:
            train_data: The performance input data for the selector
            objective: The objective used to rank the solvers

        Returns:
            SelectionSolverData holding the solver ranking and solver count.
        """
        solver_performance_ranking = train_data.get_solver_ranking(
            objective=objective)
        num_solvers = train_data.num_solvers
        return SelectionSolverData(solver_performance_ranking, num_solvers)

    def serialize_solvers(self: SelectionOutput,
                          sd: SelectionSolverData) -> dict:
        """Transform SelectionSolverData to dictionary format.

        Each entry of the solver ranking is read positionally: index 0 is the
        solver name and index 1 its performance.
        """
        return {
            "number_of_solvers": sd.num_solvers,
            "single_best_solver": sd.single_best_solver,
            "solver_ranking": [
                {
                    "solver_name": solver[0],
                    "performance": solver[1],
                }
                for solver in sd.solver_performance_ranking
            ],
        }

    def serialize_performance(self: SelectionOutput,
                              sp: SelectionPerformance) -> dict:
        """Transform SelectionPerformance to dictionary format."""
        return {
            "vbs_performance": sp.vbs_performance,
            "actual_performance": sp.actual_performance,
            "objective": self.objective.name,
            "metric": sp.metric,
        }

    def serialize_instances(self: SelectionOutput,
                            instances: list[InstanceSet]) -> dict:
        """Transform a list of instance sets to dictionary format.

        Args:
            instances: The instance sets; each must expose `name` and `size`

        Returns:
            The number of sets and, per set, its name and instance count.
        """
        return {
            "number_of_instance_sets": len(instances),
            "instance_sets": [
                {
                    "name": instance.name,
                    "number_of_instances": instance.size,
                }
                for instance in instances
            ],
        }

    def serialize_contribution(self: SelectionOutput) -> dict:
        """Transform marginal contribution rankings to dictionary format.

        Each ranking entry is read positionally: solver name, marginal
        contribution, best performance.
        """
        return {
            "marginal_contribution_actual": [
                {
                    "solver_name": ranking[0],
                    "marginal_contribution": ranking[1],
                    "best_performance": ranking[2],
                }
                for ranking in self.marginal_contribution_actual
            ],
            "marginal_contribution_perfect": [
                {
                    "solver_name": ranking[0],
                    "marginal_contribution": ranking[1],
                    "best_performance": ranking[2],
                }
                for ranking in self.marginal_contribution_perfect
            ],
        }

    def serialize_settings(self: SelectionOutput) -> dict:
        """Transform settings to dictionary format."""
        return {
            "cutoff_time": self.cutoff_time,
        }

    def write_output(self: SelectionOutput) -> None:
        """Write the collected selection data to the JSON output file.

        Parent directories of the output file are created when missing. The
        "test_instances" entry is None when there are no test instances.
        """
        test_data = (self.serialize_instances(self.test_instances)
                     if self.test_instances else None)
        output_data = {
            "solvers": self.serialize_solvers(self.solver_data),
            "training_instances": self.serialize_instances(
                self.training_instances),
            "test_instances": test_data,
            "performance": self.serialize_performance(self.performance_data),
            "settings": self.serialize_settings(),
            "marginal_contribution": self.serialize_contribution(),
        }

        self.output.parent.mkdir(parents=True, exist_ok=True)
        with self.output.open("w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4)