Coverage for sparkle/CLI/configure_solver.py: 76%

131 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-03 10:42 +0000

1#!/usr/bin/env python3 

2"""Sparkle command to configure a solver.""" 

3from __future__ import annotations 

4import argparse 

5import sys 

6import math 

7 

8from runrunner import Runner 

9 

10from sparkle.CLI.help import global_variables as gv 

11from sparkle.CLI.help import logging as sl 

12from sparkle.CLI.initialise import check_for_initialise 

13from sparkle.CLI.help.reporting_scenario import Scenario 

14from sparkle.CLI.help.nicknames import resolve_object_name 

15from sparkle.CLI.help import argparse_custom as ac 

16 

17from sparkle.platform.settings_objects import SettingState 

18from sparkle.structures import PerformanceDataFrame, FeatureDataFrame 

19from sparkle.solver import Solver 

20from sparkle.instance import Instance_Set 

21 

22 

def parser_function() -> argparse.ArgumentParser:
    """Build the argument parser for the configure-solver command.

    Returns:
        An ArgumentParser with all configure-solver CLI arguments registered.
    """
    parser = argparse.ArgumentParser(
        description="Configure a solver in the platform.",
        epilog=("Note that the test instance set is only used if the ``--ablation``"
                " or ``--validation`` flags are given"))
    # All CLI arguments for this command, registered in a fixed order.
    argument_definitions = (
        ac.ConfiguratorArgument,
        ac.SolverArgument,
        ac.InstanceSetTrainArgument,
        ac.InstanceSetTestArgument,
        ac.TestSetRunAllConfigurationArgument,
        ac.ObjectivesArgument,
        ac.TargetCutOffTimeArgument,
        ac.SolverCallsArgument,
        ac.NumberOfRunsConfigurationArgument,
        ac.SettingsFileArgument,
        ac.UseFeaturesArgument,
        ac.RunOnArgument,
    )
    for definition in argument_definitions:
        parser.add_argument(*definition.names, **definition.kwargs)
    return parser

54 

55 

def apply_settings_from_args(args: argparse.Namespace) -> None:
    """Apply command line arguments to settings.

    Each setting is only applied when the corresponding argument was
    actually supplied on the command line (i.e. is not None).

    Args:
        args: Arguments object created by ArgumentParser.
    """
    # The settings file is read first so explicit CLI values can override it.
    if args.settings_file is not None:
        gv.settings().read_settings_ini(args.settings_file, SettingState.CMD_LINE)
    # (argument attribute, settings setter) pairs, applied in order.
    setter_table = (
        ("configurator", "set_general_sparkle_configurator"),
        ("objectives", "set_general_sparkle_objectives"),
        ("target_cutoff_time", "set_general_target_cutoff_time"),
        ("solver_calls", "set_configurator_solver_calls"),
        ("number_of_runs", "set_configurator_number_of_runs"),
    )
    for attribute, setter_name in setter_table:
        value = getattr(args, attribute)
        if value is not None:
            getattr(gv.settings(), setter_name)(value, SettingState.CMD_LINE)
    # run_on is an enum-like object; the settings API takes its raw value.
    if args.run_on is not None:
        gv.settings().set_run_on(args.run_on.value, SettingState.CMD_LINE)

82 

83 

def main(argv: list[str]) -> None:
    """Main function of the configure solver command.

    Resolves the solver and instance set(s), prepares the performance data
    frame, schedules the configuration run (plus optional default-configuration
    and test-set validation jobs), records the scenario, and exits the process.

    Args:
        argv: Command line arguments (without the program name).

    Raises:
        ValueError: If the solver or training instance set cannot be resolved.
    """
    # Log command call
    # NOTE(review): logs the full process sys.argv, not the argv parameter —
    # presumably intentional so the program name is included; confirm.
    sl.log_command(sys.argv)
    check_for_initialise()

    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)

    apply_settings_from_args(args)

    # Resolve the solver by name or nickname; hard failure if unknown.
    solver: Solver = resolve_object_name(
        args.solver,
        gv.file_storage_data_mapping[gv.solver_nickname_list_path],
        gv.settings().DEFAULT_solver_dir, class_name=Solver)
    if solver is None:
        raise ValueError(f"Solver {args.solver} not found.")
    # The training instance set is mandatory for configuration.
    instance_set_train = resolve_object_name(
        args.instance_set_train,
        gv.file_storage_data_mapping[gv.instances_nickname_path],
        gv.settings().DEFAULT_instance_dir, Instance_Set)
    if instance_set_train is None:
        raise ValueError(f"Instance set {args.instance_set_train} not found.")
    # The test set is optional; note that a failed resolve leaves it None
    # silently (no error is raised here, unlike the train set above).
    instance_set_test = args.instance_set_test
    if instance_set_test is not None:
        instance_set_test = resolve_object_name(
            args.instance_set_test,
            gv.file_storage_data_mapping[gv.instances_nickname_path],
            gv.settings().DEFAULT_instance_dir, Instance_Set)
    use_features = args.use_features
    run_on = gv.settings().get_run_on()

    configurator = gv.settings().get_general_sparkle_configurator()
    configurator_settings = gv.settings().get_configurator_settings(configurator.name)

    sparkle_objectives =\
        gv.settings().get_general_sparkle_objectives()
    # Only the first objective is passed on to the configurator below.
    if len(sparkle_objectives) > 1:
        print(f"WARNING: {configurator.name} does not have multi objective support. "
              f"Only the first objective ({sparkle_objectives[0]}) will be optimised.")
    configurator_runs = gv.settings().get_configurator_number_of_runs()
    performance_data = PerformanceDataFrame(gv.settings().DEFAULT_performance_data_path)

    # Check if given objectives are in the data frame
    for objective in sparkle_objectives:
        if objective.name not in performance_data.objective_names:
            print(f"WARNING: Objective {objective.name} not found in performance data. "
                  "Adding to data frame.")
            performance_data.add_objective(objective.name)

    if use_features:
        feature_data = FeatureDataFrame(gv.settings().DEFAULT_feature_data_path)
        # Check that the train instance set is in the feature data frame
        invalid = False
        # Instances that still have uncomputed feature jobs pending.
        remaining_instance_jobs =\
            set([instance for instance, _, _ in feature_data.remaining_jobs()])
        for instance in instance_set_train.instance_paths:
            if str(instance) not in feature_data.instances:
                print(f"ERROR: Train Instance {instance} not found in feature data.")
                invalid = True
            elif instance in remaining_instance_jobs:  # Check jobs
                # NOTE(review): this membership test uses `instance` directly
                # while the check above uses str(instance); confirm that
                # remaining_jobs() yields the same type as instance_paths,
                # otherwise pending jobs may never be detected here.
                print(f"ERROR: Features have not been computed for instance {instance}.")
                invalid = True
        if invalid:
            sys.exit(-1)
        configurator_settings.update({"feature_data": feature_data})

    sbatch_options = gv.settings().get_slurm_extra_options(as_args=True)
    slurm_prepend = gv.settings().get_slurm_job_prepend()
    # Instantiate the configurator-specific scenario class with all settings.
    config_scenario = configurator.scenario_class()(
        solver, instance_set_train, sparkle_objectives,
        configurator.output_path, **configurator_settings)

    # Run the default configuration
    # Collect pending jobs for this solver that have no configuration stored
    # yet (NaN in the configuration column) — these run with the default.
    remaining_jobs = performance_data.get_job_list()
    relevant_jobs = []
    for instance, run_id, solver_id in remaining_jobs:
        # NOTE: This run_id skip will not work if we do multiple runs per configuration
        if run_id != 1 or solver_id != str(solver.directory):
            continue
        configuration = performance_data.get_value(
            solver_id, instance, sparkle_objectives[0].name, run=run_id,
            solver_fields=[PerformanceDataFrame.column_configuration])
        # Only run jobs with the default configuration
        if not isinstance(configuration, str) and math.isnan(configuration):
            relevant_jobs.append((instance, run_id, solver_id))

    # Expand the performance dataframe so it can store the configuration
    performance_data.add_runs(configurator_runs,
                              instance_names=[
                                  str(i) for i in instance_set_train.instance_paths],
                              initial_values=[PerformanceDataFrame.missing_value,
                                              PerformanceDataFrame.missing_value,
                                              {}])
    if instance_set_test is not None:
        # Expand the performance dataframe so it can store the test set results of the
        # found configurations
        test_set_runs = configurator_runs if args.test_set_run_all_configurations else 1
        performance_data.add_runs(
            test_set_runs,
            instance_names=[str(i) for i in instance_set_test.instance_paths])
    # Persist the expanded frame before any jobs are scheduled against it.
    performance_data.save_csv()

    # Launch the actual configuration runs; returns the jobs that later
    # steps (test-set validation) must depend on.
    dependency_job_list = configurator.configure(
        scenario=config_scenario,
        data_target=performance_data,
        sbatch_options=sbatch_options,
        slurm_prepend=slurm_prepend,
        num_parallel_jobs=gv.settings().get_number_of_jobs_in_parallel(),
        base_dir=sl.caller_log_dir,
        run_on=run_on)

    # If we have default configurations that need to be run, schedule them too
    if len(relevant_jobs) > 0:
        instances = [job[0] for job in relevant_jobs]
        runs = list(set([job[1] for job in relevant_jobs]))
        default_job = solver.run_performance_dataframe(
            instances, runs, performance_data,
            sbatch_options=sbatch_options,
            slurm_prepend=slurm_prepend,
            cutoff_time=config_scenario.cutoff_time,
            log_dir=config_scenario.validation,
            base_dir=sl.caller_log_dir,
            job_name=f"Default Configuration: {solver.name} Validation on "
                     f"{instance_set_train.name}",
            run_on=run_on)
        dependency_job_list.append(default_job)

    # Update latest scenario
    gv.latest_scenario().set_config_solver(solver)
    gv.latest_scenario().set_config_instance_set_train(instance_set_train.directory)
    gv.latest_scenario().set_configuration_scenario(config_scenario.scenario_file_path)
    gv.latest_scenario().set_latest_scenario(Scenario.CONFIGURATION)

    if instance_set_test is not None:
        gv.latest_scenario().set_config_instance_set_test(instance_set_test.directory)
        # Schedule test set jobs
        if args.test_set_run_all_configurations:
            # TODO: Schedule test set runs for all configurations
            print("Running all configurations on test set is not implemented yet.")
            pass
        else:
            # We place the results in the index we just added
            run_index = list(set([performance_data.get_instance_num_runs(str(i))
                                  for i in instance_set_test.instance_paths]))
            # Validate the best found configuration on the test set, after
            # all configuration (and default) jobs have finished.
            test_set_job = solver.run_performance_dataframe(
                instance_set_test,
                run_index,
                performance_data,
                cutoff_time=config_scenario.cutoff_time,
                objective=config_scenario.sparkle_objective,
                train_set=instance_set_train,
                sbatch_options=sbatch_options,
                log_dir=config_scenario.validation,
                base_dir=sl.caller_log_dir,
                dependencies=dependency_job_list,
                job_name=f"Best Configuration: {solver.name} Validation on "
                         f"{instance_set_test.name}",
                run_on=run_on)
            dependency_job_list.append(test_set_job)
    else:
        # Set to default to overwrite possible old path
        gv.latest_scenario().set_config_instance_set_test()

    if run_on == Runner.SLURM:
        job_id_str = ",".join([run.run_id for run in dependency_job_list])
        print(f"Running configuration. Waiting for Slurm job(s) with id(s): "
              f"{job_id_str}")
    else:
        print("Running configuration finished!")

    # Write used settings to file
    gv.settings().write_used_settings()
    # Write used scenario to file
    gv.latest_scenario().write_scenario_ini()
    sys.exit(0)

262 

263 

if __name__ == "__main__":
    # Script entry point: pass CLI arguments without the program name.
    main(sys.argv[1:])