Coverage for sparkle/platform/output/selection_output.py: 28%

58 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 13:21 +0000

1"""Sparkle class to organise configuration output.""" 

2from __future__ import annotations 

3 

4from sparkle.selector import SelectionScenario 

5from sparkle.structures import PerformanceDataFrame, FeatureDataFrame 

6# TODO: This dependency should be removed or the functionality should be moved 

7from sparkle.CLI.compute_marginal_contribution import \ 

8 compute_selector_marginal_contribution 

9from sparkle.platform.output.structures import SelectionPerformance, SelectionSolverData 

10 

11import json 

12from pathlib import Path 

13 

14 

class SelectionOutput:
    """Class that collects selection data and outputs it a JSON format."""

    def __init__(self: SelectionOutput,
                 selection_scenario: SelectionScenario,
                 feature_data: FeatureDataFrame) -> None:
        """Initialize SelectionOutput class.

        Args:
            selection_scenario: The selection scenario to collect output data from
            feature_data: The feature data used for the selector
        """
        # Training instance paths plus, per training instance set, the number of
        # training instances whose path contains the set name (substring match).
        self.training_instances = selection_scenario.training_instances
        training_instance_sets = selection_scenario.training_instance_sets
        self.training_instance_sets =\
            [(instance_set, sum(instance_set in s for s in self.training_instances))
             for instance_set in training_instance_sets]
        # Same bookkeeping for the test instances / test sets (may be empty).
        self.test_instances = selection_scenario.test_instances
        test_sets = selection_scenario.test_instance_sets
        self.test_sets =\
            [(instance_set, sum(instance_set in s for s in self.test_instances))
             for instance_set in test_sets]
        self.cutoff_time = selection_scenario.solver_cutoff
        self.objective = selection_scenario.objective

        # Work on a clone with the virtual "selector" solver column removed, so
        # rankings/aggregates below only consider the actual solvers.
        solver_performance_data = selection_scenario.selector_performance_data.clone()
        solver_performance_data.remove_solver(SelectionScenario.__selector_solver_name__)

        # Solver ranking on the training instances; entries are indexed below as
        # entry[0] = solver, entry[1] = configuration.
        self.solver_performance_ranking =\
            solver_performance_data.get_solver_ranking(
                instances=self.training_instances,
                objective=self.objective)

        self.solver_data = self.get_solver_data(solver_performance_data)
        # Map solver name -> list of its configurations.
        # NOTE(review): assumes column names follow "<solver>_<configuration>";
        # a solver name that itself contains "_" would split wrongly — confirm.
        self.solvers = {}
        for solver_conf in selection_scenario.performance_data.columns:
            solver, conf = solver_conf.split("_", maxsplit=1)
            if solver not in self.solvers:
                self.solvers[solver] = []
            self.solvers[solver].append(conf)

        # Performance of the single best solver (top-ranked) on the training set.
        self.sbs_performance = solver_performance_data.get_value(
            solver=self.solver_performance_ranking[0][0],
            configuration=self.solver_performance_ranking[0][1],
            instance=self.training_instances,
            objective=self.objective.name)

        # Collect marginal contribution data
        self.marginal_contribution_perfect =\
            solver_performance_data.marginal_contribution(
                selection_scenario.objective,
                instances=self.training_instances,
                sort=True)

        self.marginal_contribution_actual = \
            compute_selector_marginal_contribution(feature_data,
                                                   selection_scenario)
        # Collect performance data
        # Virtual-best-solver performance: per-instance best over the training
        # instances, then aggregated with the objective's instance aggregator.
        self.vbs_performance_data = solver_performance_data.best_instance_performance(
            instances=self.training_instances,
            objective=selection_scenario.objective)
        self.vbs_performance = selection_scenario.objective.instance_aggregator(
            self.vbs_performance_data)

        # Selector performance per test set; None when there are no test sets.
        self.test_set_performance = {} if self.test_sets else None
        for (test_set, _) in self.test_sets:
            test_set_instances = [instance for instance in self.test_instances
                                  if test_set in instance]
            # Exclude every real solver so only the virtual selector entry
            # remains: this yields the selector's performance on this test set.
            test_perf = selection_scenario.selector_performance_data.best_performance(
                exclude_solvers=[
                    s for s in selection_scenario.selector_performance_data.solvers
                    if s != SelectionScenario.__selector_solver_name__],
                instances=test_set_instances,
                objective=selection_scenario.objective
            )
            self.test_set_performance[test_set] = test_perf
        # The selector's measured (actual) performance on the training instances.
        self.actual_performance_data =\
            selection_scenario.selector_performance_data.get_value(
                solver=SelectionScenario.__selector_solver_name__,
                instance=self.training_instances,
                objective=self.objective.name)
        self.actual_performance = self.objective.instance_aggregator(
            self.actual_performance_data)

100 def get_solver_data(self: SelectionOutput, 

101 train_data: PerformanceDataFrame) -> SelectionSolverData: 

102 """Initalise SelectionSolverData object.""" 

103 num_solvers = train_data.num_solvers 

104 return SelectionSolverData(self.solver_performance_ranking, 

105 num_solvers) 

106 

107 def serialise_solvers(self: SelectionOutput, 

108 sd: SelectionSolverData) -> dict: 

109 """Transform SelectionSolverData to dictionary format.""" 

110 return { 

111 "number_of_solvers": sd.num_solvers, 

112 "single_best_solver": sd.single_best_solver, 

113 "solver_ranking": [ 

114 { 

115 "solver_name": solver[0], 

116 "performance": solver[1] 

117 } 

118 for solver in sd.solver_performance_ranking 

119 ] 

120 } 

121 

122 def serialise_performance(self: SelectionOutput, 

123 sp: SelectionPerformance) -> dict: 

124 """Transform SelectionPerformance to dictionary format.""" 

125 return { 

126 "vbs_performance": sp.vbs_performance, 

127 "actual_performance": sp.actual_performance, 

128 "objective": self.objective.name, 

129 "metric": sp.metric 

130 } 

131 

132 def serialise_instances(self: SelectionOutput, 

133 instances: list[str]) -> dict: 

134 """Transform Instances to dictionary format.""" 

135 instance_sets = set(Path(instance).parent.name for instance in instances) 

136 return { 

137 "number_of_instance_sets": len(instance_sets), 

138 "instance_sets": [ 

139 { 

140 "name": instance_set, 

141 "number_of_instances": sum([1 if instance_set in instance else 0 

142 for instance in instances]) 

143 } 

144 for instance_set in instance_sets 

145 ] 

146 } 

147 

148 def serialise_marginal_contribution(self: SelectionOutput) -> dict: 

149 """Transform performance ranking to dictionary format.""" 

150 return { 

151 "marginal_contribution_actual": [ 

152 { 

153 "solver_name": ranking[0], 

154 "marginal_contribution": ranking[1], 

155 "best_performance": ranking[2] 

156 } 

157 for ranking in self.marginal_contribution_actual 

158 ], 

159 "marginal_contribution_perfect": [ 

160 { 

161 "solver_name": ranking[0], 

162 "marginal_contribution": ranking[1], 

163 "best_performance": ranking[2] 

164 } 

165 for ranking in self.marginal_contribution_perfect 

166 ] 

167 } 

168 

169 def serialise(self: SelectionOutput) -> dict: 

170 """Serialise the selection output.""" 

171 test_data = self.serialise_instances(self.test_instances) if self.test_instances\ 

172 else None 

173 return { 

174 "solvers": self.serialise_solvers(self.solver_data), 

175 "training_instances": self.serialise_instances(self.training_instances), 

176 "test_instances": test_data, 

177 "settings": {"cutoff_time": self.cutoff_time}, 

178 "marginal_contribution": self.serialise_marginal_contribution() 

179 } 

180 

181 def write_output(self: SelectionOutput, output: Path) -> None: 

182 """Write data into a JSON file.""" 

183 output = output / "configuration.json" if output.is_dir() else output 

184 with output.open("w") as f: 

185 json.dump(self.serialise(), f, indent=4)