Coverage for sparkle/CLI/configure_solver.py: 76%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1#!/usr/bin/env python3 

2"""Sparkle command to configure a solver.""" 

3from __future__ import annotations 

4 

5import argparse 

6import sys 

7import os 

8from pathlib import Path 

9from pandas import DataFrame 

10 

11from runrunner.base import Runner, Run 

12import runrunner as rrr 

13 

14from sparkle.CLI.help import global_variables as gv 

15from sparkle.CLI.help import logging as sl 

16from sparkle.platform.settings_objects import SettingState 

17from sparkle.CLI.help.reporting_scenario import Scenario 

18from sparkle.structures import FeatureDataFrame 

19from sparkle.platform import CommandName, COMMAND_DEPENDENCIES 

20from sparkle.configurator import implementations as configurator_implementations 

21from sparkle.CLI.help.nicknames import resolve_object_name 

22from sparkle.solver import Solver 

23from sparkle.CLI.initialise import check_for_initialise 

24from sparkle.CLI.help import argparse_custom as ac 

25from sparkle.instance import Instance_Set, InstanceSet 

26 

27 

def parser_function() -> argparse.ArgumentParser:
    """Build the argument parser of the configure solver command.

    Returns:
        The ArgumentParser with all arguments of this command registered.
    """
    epilog = ("Note that the test instance set is only used if the ``--ablation``"
              " or ``--validation`` flags are given")
    parser = argparse.ArgumentParser(
        description="Configure a solver in the platform.",
        epilog=epilog)

    # Argument definitions, listed in the order they are added to the parser
    cli_arguments = (
        ac.ConfiguratorArgument,
        ac.SolverArgument,
        ac.InstanceSetTrainArgument,
        ac.InstanceSetTestArgument,
        ac.SparkleObjectiveArgument,
        ac.TargetCutOffTimeConfigurationArgument,
        ac.SolverCallsArgument,
        ac.NumberOfRunsConfigurationArgument,
        ac.SettingsFileArgument,
        ac.UseFeaturesArgument,
        ac.ValidateArgument,
        ac.AblationArgument,
        ac.RunOnArgument,
    )
    for cli_argument in cli_arguments:
        parser.add_argument(*cli_argument.names, **cli_argument.kwargs)
    return parser

61 

62 

def apply_settings_from_args(args: argparse.Namespace) -> None:
    """Transfer the given command line arguments into the settings.

    Only arguments that were actually given (i.e. that are not None) are applied,
    each of them with the CMD_LINE setting state.

    Args:
        args: Arguments object created by ArgumentParser.
    """
    origin = SettingState.CMD_LINE

    # The settings file is read first, the individual arguments are applied after
    if args.settings_file is not None:
        gv.settings().read_settings_ini(args.settings_file, origin)

    # Pairs of (argument value, name of the settings method receiving it),
    # in the order in which they are applied
    overrides = (
        (args.configurator, "set_general_sparkle_configurator"),
        (args.objectives, "set_general_sparkle_objectives"),
        (args.target_cutoff_time, "set_general_target_cutoff_time"),
        (args.solver_calls, "set_configurator_solver_calls"),
        (args.number_of_runs, "set_configurator_number_of_runs"),
    )
    for value, setter_name in overrides:
        if value is None:
            continue
        getattr(gv.settings(), setter_name)(value, origin)

    # run_on is handed over through its ``value`` attribute, not as the object
    if args.run_on is not None:
        gv.settings().set_run_on(args.run_on.value, origin)

89 

90 

def run_after(solver: Path,
              train_set: InstanceSet,
              test_set: InstanceSet,
              dependency: list[Run],
              command: CommandName,
              run_on: Runner = Runner.SLURM) -> Run:
    """Queue a follow-up command in RunRunner that runs after the configuration.

    Args:
        solver: The solver to run the command for; only its ``name`` is used.
        train_set: Instances used for training.
        test_set: Instances used for testing. May be None, in which case no test
            set is passed on to the command.
        dependency: List of jobs that are passed to RunRunner as dependencies.
        command: The command to run. RUN_ABLATION selects the ablation script,
            any other value selects the validation script.
        run_on: Whether the job is executed on Slurm or locally. For a local
            run this function waits until the job has finished.

    Returns:
        RunRunner Run object of the queued command.
    """
    script = ("run_ablation.py" if command == CommandName.RUN_ABLATION
              else "validate_configured_vs_default.py")

    # Every part is a separate piece of the command line, joined by single spaces
    parts = [
        f"./sparkle/CLI/{script}",
        "--settings-file Settings/latest.ini",
        f"--solver {solver.name}",
        f"--instance-set-train {train_set.directory}",
        f"--run-on {run_on}",
    ]
    if test_set is not None:
        parts.append(f"--instance-set-test {test_set.directory}")
    command_line = " ".join(parts)

    queued_run = rrr.add_to_queue(
        runner=run_on,
        cmd=command_line,
        name=command,
        dependencies=dependency,
        base_dir=sl.caller_log_dir,
        srun_options=["-N1", "-n1"],
        sbatch_options=gv.settings().get_slurm_extra_options(as_args=True))

    if run_on != Runner.LOCAL:
        return queued_run

    print("Waiting for the local calculations to finish.")
    queued_run.wait()
    return queued_run

133 

134 

def _load_training_feature_data(instance_set_train: InstanceSet) -> DataFrame:
    """Collect the feature data of the training set for the configurator.

    Reads the feature data from the default feature data path, keeps the rows
    whose label lies in a directory named like the training set, and relabels
    them as ``../../../instances/<set name>/<file name>``. The columns are
    renamed to ``Feature1``, ``Feature2``, ... by position.

    Exits the program with code -1 if a selected row is empty or if the feature
    data has missing values.

    Args:
        instance_set_train: The instance set used for training.

    Returns:
        DataFrame with one row per selected instance.
    """
    feature_data = FeatureDataFrame(gv.settings().DEFAULT_feature_data_path)
    all_features = feature_data.dataframe

    data_dict = {}
    for label, row in all_features.iterrows():
        # os.path.split(os.path.split(label)[0])[1] gives the dir/instance set name
        if os.path.split(os.path.split(label)[0])[1] != instance_set_train.name:
            continue
        if row.empty:
            print("No feature data exists for the given training set, please "
                  "run add_feature_extractor.py, then compute_features.py")
            sys.exit(-1)

        new_label = (f"../../../instances/{instance_set_train.name}/"
                     + os.path.split(label)[1])
        data_dict[new_label] = row

    feature_data_df = DataFrame.from_dict(data_dict, orient="index",
                                          columns=all_features.columns)

    if feature_data.has_missing_value():
        print("You have unfinished feature computation jobs, please run "
              "`sparkle compute features`")
        sys.exit(-1)

    # Relabel by position in one step. Renaming the columns one by one can merge
    # two columns when an existing name equals an already generated "FeatureN".
    feature_data_df.columns = [
        f"Feature{index + 1}" for index in range(len(feature_data_df.columns))]
    return feature_data_df


def main(argv: list[str]) -> None:
    """Main function of the configure solver command.

    Parses the arguments, starts the configuration of the solver on the training
    instance set, optionally queues validation and ablation to run after it, and
    writes the used settings and the latest scenario to file. On success the
    program is ended through ``sys.exit(0)``.

    Args:
        argv: The command line arguments, without the program name.

    Raises:
        ValueError: If the solver, the training instance set or a given test
            instance set can not be found.
    """
    # Log command call
    sl.log_command(sys.argv)

    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)

    apply_settings_from_args(args)

    validate = args.validate
    ablation = args.ablation

    # Resolve the solver and instance sets, each of them must be found
    solver = resolve_object_name(
        args.solver,
        gv.file_storage_data_mapping[gv.solver_nickname_list_path],
        gv.settings().DEFAULT_solver_dir, class_name=Solver)
    if solver is None:
        raise ValueError(f"Solver {args.solver} not found.")
    instance_set_train = resolve_object_name(
        args.instance_set_train,
        gv.file_storage_data_mapping[gv.instances_nickname_path],
        gv.settings().DEFAULT_instance_dir, Instance_Set)
    if instance_set_train is None:
        raise ValueError(f"Instance set {args.instance_set_train} not found.")
    instance_set_test = args.instance_set_test
    if instance_set_test is not None:
        instance_set_test = resolve_object_name(
            args.instance_set_test,
            gv.file_storage_data_mapping[gv.instances_nickname_path],
            gv.settings().DEFAULT_instance_dir, Instance_Set)
        if instance_set_test is None:
            # A test set was asked for explicitly, so do not continue without it
            raise ValueError(f"Instance set {args.instance_set_test} not found.")
    use_features = args.use_features
    run_on = gv.settings().get_run_on()

    # Check the requirements registered for this command
    check_for_initialise(COMMAND_DEPENDENCIES[CommandName.CONFIGURE_SOLVER])

    configurator = gv.settings().get_general_sparkle_configurator()
    configurator_settings = gv.settings().get_configurator_settings(configurator.name)
    # Feature data is only handed over when the configurator is SMAC2
    if use_features and configurator.name == configurator_implementations.SMAC2.__name__:
        configurator_settings.update(
            {"feature_data_df": _load_training_feature_data(instance_set_train)})

    sparkle_objectives = gv.settings().get_general_sparkle_objectives()
    config_scenario = configurator.scenario_class(
        solver, instance_set_train, sparkle_objectives, configurator.output_path,
        **configurator_settings)

    sbatch_options = gv.settings().get_slurm_extra_options(as_args=True)
    dependency_job_list = configurator.configure(
        scenario=config_scenario,
        sbatch_options=sbatch_options,
        num_parallel_jobs=gv.settings().get_number_of_jobs_in_parallel(),
        base_dir=sl.caller_log_dir,
        run_on=run_on)

    # Update latest scenario
    gv.latest_scenario().set_config_solver(solver)
    gv.latest_scenario().set_config_instance_set_train(instance_set_train.directory)
    gv.latest_scenario().set_configuration_scenario(config_scenario.scenario_file_path)
    gv.latest_scenario().set_latest_scenario(Scenario.CONFIGURATION)

    if instance_set_test is not None:
        gv.latest_scenario().set_config_instance_set_test(instance_set_test.directory)
    else:
        # Set to default to overwrite possible old path
        gv.latest_scenario().set_config_instance_set_test()

    # Set validation to wait until configuration is done
    if validate:
        validate_jobid = run_after(
            solver, instance_set_train, instance_set_test, dependency_job_list,
            command=CommandName.VALIDATE_CONFIGURED_VS_DEFAULT, run_on=run_on
        )
        dependency_job_list.append(validate_jobid)

    # Ablation depends on the configuration and, if requested, on the validation
    if ablation:
        ablation_jobid = run_after(
            solver, instance_set_train, instance_set_test, dependency_job_list,
            command=CommandName.RUN_ABLATION, run_on=run_on
        )
        dependency_job_list.append(ablation_jobid)

    if run_on == Runner.SLURM:
        job_id_str = ",".join(run.run_id for run in dependency_job_list)
        print(f"Running configuration. Waiting for Slurm job(s) with id(s): "
              f"{job_id_str}")
    else:
        print("Running configuration finished!")

    # Write used settings to file
    gv.settings().write_used_settings()
    # Write used scenario to file
    gv.latest_scenario().write_scenario_ini()
    sys.exit(0)

258 

259 

# Run the command with the arguments given on the command line (without the
# program name) when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv[1:])