Coverage for sparkle/CLI/configure_solver.py: 77%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 15:22 +0000

1#!/usr/bin/env python3 

2"""Sparkle command to configure a solver.""" 

3from __future__ import annotations 

4import argparse 

5import sys 

6import math 

7 

8from runrunner import Runner 

9 

10from sparkle.CLI.help import global_variables as gv 

11from sparkle.CLI.help import logging as sl 

12from sparkle.CLI.initialise import check_for_initialise 

13from sparkle.CLI.help.reporting_scenario import Scenario 

14from sparkle.CLI.help.nicknames import resolve_object_name 

15from sparkle.CLI.help import argparse_custom as ac 

16 

17from sparkle.platform.settings_objects import SettingState 

18from sparkle.structures import PerformanceDataFrame, FeatureDataFrame 

19from sparkle.solver import Solver 

20from sparkle.instance import Instance_Set 

21 

22 

def parser_function() -> argparse.ArgumentParser:
    """Define the command line arguments."""
    parser = argparse.ArgumentParser(
        description="Configure a solver in the platform.",
        epilog=("Note that the test instance set is only used if the ``--ablation``"
                " or ``--validation`` flags are given"))
    # All arguments accepted by this command, registered in display order.
    command_arguments = (
        ac.ConfiguratorArgument,
        ac.SolverArgument,
        ac.InstanceSetTrainArgument,
        ac.InstanceSetTestArgument,
        ac.TestSetRunAllConfigurationArgument,
        ac.ObjectivesArgument,
        ac.TargetCutOffTimeArgument,
        ac.SolverCallsArgument,
        ac.NumberOfRunsConfigurationArgument,
        ac.SettingsFileArgument,
        ac.UseFeaturesArgument,
        ac.RunOnArgument,
    )
    for argument in command_arguments:
        parser.add_argument(*argument.names, **argument.kwargs)
    return parser

54 

55 

def apply_settings_from_args(args: argparse.Namespace) -> None:
    """Override platform settings with values given on the command line.

    Only arguments that were actually passed (i.e. are not None) are applied.

    Args:
        args: Arguments object created by ArgumentParser.
    """
    # All command-line overrides share the same setting state.
    state = SettingState.CMD_LINE
    if args.settings_file is not None:
        gv.settings().read_settings_ini(args.settings_file, state)
    if args.configurator is not None:
        gv.settings().set_general_sparkle_configurator(args.configurator, state)
    if args.objectives is not None:
        gv.settings().set_general_sparkle_objectives(args.objectives, state)
    if args.target_cutoff_time is not None:
        gv.settings().set_general_target_cutoff_time(args.target_cutoff_time, state)
    if args.solver_calls is not None:
        gv.settings().set_configurator_solver_calls(args.solver_calls, state)
    if args.number_of_runs is not None:
        gv.settings().set_configurator_number_of_runs(args.number_of_runs, state)
    if args.run_on is not None:
        # run_on is parsed as an enum; the settings layer stores its raw value.
        gv.settings().set_run_on(args.run_on.value, state)

82 

83 

def main(argv: list[str]) -> None:
    """Main function of the configure solver command.

    Parses the command line, resolves the solver and instance sets, prepares
    the performance (and optionally feature) data frames, schedules the
    configuration run plus any default-configuration and test-set validation
    jobs, records the scenario, and exits the process (sys.exit).

    Args:
        argv: Command line arguments, excluding the program name.
    """
    # Log command call
    sl.log_command(sys.argv)
    check_for_initialise()

    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)

    apply_settings_from_args(args)

    # Resolve the solver by nickname or path; fail fast if it does not exist.
    solver: Solver = resolve_object_name(
        args.solver,
        gv.file_storage_data_mapping[gv.solver_nickname_list_path],
        gv.settings().DEFAULT_solver_dir, class_name=Solver)
    if solver is None:
        raise ValueError(f"Solver {args.solver} not found.")
    instance_set_train = resolve_object_name(
        args.instance_set_train,
        gv.file_storage_data_mapping[gv.instances_nickname_path],
        gv.settings().DEFAULT_instance_dir, Instance_Set)
    if instance_set_train is None:
        raise ValueError(f"Instance set {args.instance_set_train} not found.")
    # The test set is optional; NOTE(review): an unresolvable name leaves it
    # None silently here (no error is raised, unlike the train set).
    instance_set_test = args.instance_set_test
    if instance_set_test is not None:
        instance_set_test = resolve_object_name(
            args.instance_set_test,
            gv.file_storage_data_mapping[gv.instances_nickname_path],
            gv.settings().DEFAULT_instance_dir, Instance_Set)
    use_features = args.use_features
    run_on = gv.settings().get_run_on()

    configurator = gv.settings().get_general_sparkle_configurator()
    configurator_settings = gv.settings().get_configurator_settings(configurator.name)

    sparkle_objectives =\
        gv.settings().get_general_sparkle_objectives()
    configurator_runs = gv.settings().get_configurator_number_of_runs()
    performance_data = PerformanceDataFrame(gv.settings().DEFAULT_performance_data_path)

    # Check if given objectives are in the data frame
    for objective in sparkle_objectives:
        if objective.name not in performance_data.objective_names:
            print(f"WARNING: Objective {objective.name} not found in performance data. "
                  "Adding to data frame.")
            performance_data.add_objective(objective.name)

    if use_features:
        feature_data = FeatureDataFrame(gv.settings().DEFAULT_feature_data_path)
        # Check that the train instance set is in the feature data frame
        invalid = False
        # Instances that still have pending feature-computation jobs.
        remaining_instance_jobs =\
            set([instance for instance, _, _ in feature_data.remaining_jobs()])
        for instance in instance_set_train.instance_paths:
            if str(instance) not in feature_data.instances:
                print(f"ERROR: Train Instance {instance} not found in feature data.")
                invalid = True
            elif instance in remaining_instance_jobs:  # Check jobs
                print(f"ERROR: Features have not been computed for instance {instance}.")
                invalid = True
        if invalid:
            sys.exit(-1)
        configurator_settings.update({"feature_data": feature_data})

    sbatch_options = gv.settings().get_slurm_extra_options(as_args=True)
    # scenario_class() returns the configurator-specific scenario type,
    # which is then instantiated with this command's settings.
    config_scenario = configurator.scenario_class()(
        solver, instance_set_train, sparkle_objectives,
        configurator.output_path, **configurator_settings)

    # Run the default configuration
    # Collect jobs for this solver that have no configuration recorded yet
    # (an empty configuration cell reads back as NaN, hence the isnan check).
    remaining_jobs = performance_data.get_job_list()
    relevant_jobs = []
    for instance, run_id, solver_id in remaining_jobs:
        # NOTE: This run_id skip will not work if we do multiple runs per configuration
        if run_id != 1 or solver_id != str(solver.directory):
            continue
        configuration = performance_data.get_value(
            solver_id, instance, sparkle_objectives[0].name, run=run_id,
            solver_fields=[PerformanceDataFrame.column_configuration])
        # Only run jobs with the default configuration
        if not isinstance(configuration, str) and math.isnan(configuration):
            relevant_jobs.append((instance, run_id, solver_id))

    # Expand the performance dataframe so it can store the configuration
    performance_data.add_runs(configurator_runs,
                              instance_names=[
                                  str(i) for i in instance_set_train.instance_paths])
    if instance_set_test is not None:
        # Expand the performance dataframe so it can store the test set results of the
        # found configurations
        test_set_runs = configurator_runs if args.test_set_run_all_configurations else 1
        performance_data.add_runs(
            test_set_runs,
            instance_names=[str(i) for i in instance_set_test.instance_paths])
    performance_data.save_csv()

    # Launch the configuration jobs; returned handles are used as
    # dependencies for the validation jobs scheduled below.
    dependency_job_list = configurator.configure(
        scenario=config_scenario,
        data_target=performance_data,
        sbatch_options=sbatch_options,
        num_parallel_jobs=gv.settings().get_number_of_jobs_in_parallel(),
        base_dir=sl.caller_log_dir,
        run_on=run_on)

    # If we have default configurations that need to be run, schedule them too
    if len(relevant_jobs) > 0:
        instances = [job[0] for job in relevant_jobs]
        runs = list(set([job[1] for job in relevant_jobs]))
        default_job = solver.run_performance_dataframe(
            instances, runs, performance_data,
            sbatch_options=sbatch_options,
            cutoff_time=config_scenario.cutoff_time,
            log_dir=config_scenario.validation,
            base_dir=sl.caller_log_dir,
            job_name=f"Default Configuration: {solver.name} Validation on "
                     f"{instance_set_train.name}",
            run_on=run_on)
        dependency_job_list.append(default_job)

    # Update latest scenario
    gv.latest_scenario().set_config_solver(solver)
    gv.latest_scenario().set_config_instance_set_train(instance_set_train.directory)
    gv.latest_scenario().set_configuration_scenario(config_scenario.scenario_file_path)
    gv.latest_scenario().set_latest_scenario(Scenario.CONFIGURATION)

    if instance_set_test is not None:
        gv.latest_scenario().set_config_instance_set_test(instance_set_test.directory)
        # Schedule test set jobs
        if args.test_set_run_all_configurations:
            # TODO: Schedule test set runs for all configurations
            print("Running all configurations on test set is not implemented yet.")
            pass
        else:
            # We place the results in the index we just added
            # (add_runs above appended one run per test instance).
            run_index = list(set([performance_data.get_instance_num_runs(str(i))
                                  for i in instance_set_test.instance_paths]))
            test_set_job = solver.run_performance_dataframe(
                instance_set_test,
                run_index,
                performance_data,
                cutoff_time=config_scenario.cutoff_time,
                objective=config_scenario.sparkle_objective,
                train_set=instance_set_train,
                sbatch_options=sbatch_options,
                log_dir=config_scenario.validation,
                base_dir=sl.caller_log_dir,
                dependencies=dependency_job_list,
                job_name=f"Best Configuration: {solver.name} Validation on "
                         f"{instance_set_test.name}",
                run_on=run_on)
            dependency_job_list.append(test_set_job)
    else:
        # Set to default to overwrite possible old path
        gv.latest_scenario().set_config_instance_set_test()

    if run_on == Runner.SLURM:
        job_id_str = ",".join([run.run_id for run in dependency_job_list])
        print(f"Running configuration. Waiting for Slurm job(s) with id(s): "
              f"{job_id_str}")
    else:
        print("Running configuration finished!")

    # Write used settings to file
    gv.settings().write_used_settings()
    # Write used scenario to file
    gv.latest_scenario().write_scenario_ini()
    sys.exit(0)

253 

254 

if __name__ == "__main__":
    # Script entry point: forward CLI arguments (excluding the program name).
    main(sys.argv[1:])