Coverage for sparkle/CLI/configure_solver.py: 76%
131 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-03 10:42 +0000
1#!/usr/bin/env python3
2"""Sparkle command to configure a solver."""
3from __future__ import annotations
4import argparse
5import sys
6import math
8from runrunner import Runner
10from sparkle.CLI.help import global_variables as gv
11from sparkle.CLI.help import logging as sl
12from sparkle.CLI.initialise import check_for_initialise
13from sparkle.CLI.help.reporting_scenario import Scenario
14from sparkle.CLI.help.nicknames import resolve_object_name
15from sparkle.CLI.help import argparse_custom as ac
17from sparkle.platform.settings_objects import SettingState
18from sparkle.structures import PerformanceDataFrame, FeatureDataFrame
19from sparkle.solver import Solver
20from sparkle.instance import Instance_Set
def parser_function() -> argparse.ArgumentParser:
    """Define the command line arguments."""
    parser = argparse.ArgumentParser(
        description="Configure a solver in the platform.",
        epilog=("Note that the test instance set is only used if the ``--ablation``"
                " or ``--validation`` flags are given"))
    # Each entry bundles its own flag names and keyword arguments; register
    # them in order so help output and parsing behaviour stay unchanged.
    command_arguments = (
        ac.ConfiguratorArgument,
        ac.SolverArgument,
        ac.InstanceSetTrainArgument,
        ac.InstanceSetTestArgument,
        ac.TestSetRunAllConfigurationArgument,
        ac.ObjectivesArgument,
        ac.TargetCutOffTimeArgument,
        ac.SolverCallsArgument,
        ac.NumberOfRunsConfigurationArgument,
        ac.SettingsFileArgument,
        ac.UseFeaturesArgument,
        ac.RunOnArgument,
    )
    for argument in command_arguments:
        parser.add_argument(*argument.names, **argument.kwargs)
    return parser
def apply_settings_from_args(args: argparse.Namespace) -> None:
    """Apply command line arguments to settings.

    Args:
        args: Arguments object created by ArgumentParser.
    """
    # All values supplied on the command line share the same setting state.
    # NOTE: gv.settings() is looked up per branch (not cached) so that any
    # effect of read_settings_ini on the settings object is observed.
    state = SettingState.CMD_LINE
    if args.settings_file is not None:
        gv.settings().read_settings_ini(args.settings_file, state)
    if args.configurator is not None:
        gv.settings().set_general_sparkle_configurator(args.configurator, state)
    if args.objectives is not None:
        gv.settings().set_general_sparkle_objectives(args.objectives, state)
    if args.target_cutoff_time is not None:
        gv.settings().set_general_target_cutoff_time(args.target_cutoff_time, state)
    if args.solver_calls is not None:
        gv.settings().set_configurator_solver_calls(args.solver_calls, state)
    if args.number_of_runs is not None:
        gv.settings().set_configurator_number_of_runs(args.number_of_runs, state)
    if args.run_on is not None:
        gv.settings().set_run_on(args.run_on.value, state)
def main(argv: list[str]) -> None:
    """Main function of the configure solver command.

    Resolves the solver and instance set(s), prepares the performance data
    frame, schedules the configuration runs (plus default-configuration and
    test-set validation jobs where applicable) and records the scenario.

    Args:
        argv: Command line arguments, excluding the program name.

    Raises:
        ValueError: If the solver, the train instance set, or a requested
            test instance set cannot be resolved.
    """
    # Log command call
    sl.log_command(sys.argv)
    check_for_initialise()

    parser = parser_function()

    # Process command line arguments
    args = parser.parse_args(argv)

    apply_settings_from_args(args)

    solver: Solver = resolve_object_name(
        args.solver,
        gv.file_storage_data_mapping[gv.solver_nickname_list_path],
        gv.settings().DEFAULT_solver_dir, class_name=Solver)
    if solver is None:
        raise ValueError(f"Solver {args.solver} not found.")
    instance_set_train = resolve_object_name(
        args.instance_set_train,
        gv.file_storage_data_mapping[gv.instances_nickname_path],
        gv.settings().DEFAULT_instance_dir, Instance_Set)
    if instance_set_train is None:
        raise ValueError(f"Instance set {args.instance_set_train} not found.")
    instance_set_test = args.instance_set_test
    if instance_set_test is not None:
        instance_set_test = resolve_object_name(
            args.instance_set_test,
            gv.file_storage_data_mapping[gv.instances_nickname_path],
            gv.settings().DEFAULT_instance_dir, Instance_Set)
        # BUGFIX: a test set that failed to resolve was previously ignored
        # silently (treated as "no test set"); fail fast like the train set.
        if instance_set_test is None:
            raise ValueError(f"Instance set {args.instance_set_test} not found.")
    use_features = args.use_features
    run_on = gv.settings().get_run_on()

    configurator = gv.settings().get_general_sparkle_configurator()
    configurator_settings = gv.settings().get_configurator_settings(configurator.name)

    sparkle_objectives =\
        gv.settings().get_general_sparkle_objectives()
    if len(sparkle_objectives) > 1:
        print(f"WARNING: {configurator.name} does not have multi objective support. "
              f"Only the first objective ({sparkle_objectives[0]}) will be optimised.")
    configurator_runs = gv.settings().get_configurator_number_of_runs()
    performance_data = PerformanceDataFrame(gv.settings().DEFAULT_performance_data_path)

    # Check if given objectives are in the data frame
    for objective in sparkle_objectives:
        if objective.name not in performance_data.objective_names:
            print(f"WARNING: Objective {objective.name} not found in performance data. "
                  "Adding to data frame.")
            performance_data.add_objective(objective.name)

    if use_features:
        feature_data = FeatureDataFrame(gv.settings().DEFAULT_feature_data_path)
        # Check that the train instance set is in the feature data frame
        invalid = False
        remaining_instance_jobs =\
            set([instance for instance, _, _ in feature_data.remaining_jobs()])
        for instance in instance_set_train.instance_paths:
            if str(instance) not in feature_data.instances:
                print(f"ERROR: Train Instance {instance} not found in feature data.")
                invalid = True
            elif instance in remaining_instance_jobs:  # Check jobs
                print(f"ERROR: Features have not been computed for instance {instance}.")
                invalid = True
        if invalid:
            sys.exit(-1)
        configurator_settings.update({"feature_data": feature_data})

    sbatch_options = gv.settings().get_slurm_extra_options(as_args=True)
    slurm_prepend = gv.settings().get_slurm_job_prepend()
    config_scenario = configurator.scenario_class()(
        solver, instance_set_train, sparkle_objectives,
        configurator.output_path, **configurator_settings)

    # Run the default configuration
    remaining_jobs = performance_data.get_job_list()
    relevant_jobs = []
    for instance, run_id, solver_id in remaining_jobs:
        # NOTE: This run_id skip will not work if we do multiple runs per configuration
        if run_id != 1 or solver_id != str(solver.directory):
            continue
        configuration = performance_data.get_value(
            solver_id, instance, sparkle_objectives[0].name, run=run_id,
            solver_fields=[PerformanceDataFrame.column_configuration])
        # Only run jobs with the default configuration
        # (a NaN configuration value marks "no configuration set yet")
        if not isinstance(configuration, str) and math.isnan(configuration):
            relevant_jobs.append((instance, run_id, solver_id))

    # Expand the performance dataframe so it can store the configuration
    performance_data.add_runs(configurator_runs,
                              instance_names=[
                                  str(i) for i in instance_set_train.instance_paths],
                              initial_values=[PerformanceDataFrame.missing_value,
                                              PerformanceDataFrame.missing_value,
                                              {}])
    if instance_set_test is not None:
        # Expand the performance dataframe so it can store the test set results of the
        # found configurations
        test_set_runs = configurator_runs if args.test_set_run_all_configurations else 1
        performance_data.add_runs(
            test_set_runs,
            instance_names=[str(i) for i in instance_set_test.instance_paths])
    performance_data.save_csv()

    dependency_job_list = configurator.configure(
        scenario=config_scenario,
        data_target=performance_data,
        sbatch_options=sbatch_options,
        slurm_prepend=slurm_prepend,
        num_parallel_jobs=gv.settings().get_number_of_jobs_in_parallel(),
        base_dir=sl.caller_log_dir,
        run_on=run_on)

    # If we have default configurations that need to be run, schedule them too
    if len(relevant_jobs) > 0:
        instances = [job[0] for job in relevant_jobs]
        runs = list(set([job[1] for job in relevant_jobs]))
        default_job = solver.run_performance_dataframe(
            instances, runs, performance_data,
            sbatch_options=sbatch_options,
            slurm_prepend=slurm_prepend,
            cutoff_time=config_scenario.cutoff_time,
            log_dir=config_scenario.validation,
            base_dir=sl.caller_log_dir,
            job_name=f"Default Configuration: {solver.name} Validation on "
                     f"{instance_set_train.name}",
            run_on=run_on)
        dependency_job_list.append(default_job)

    # Update latest scenario
    gv.latest_scenario().set_config_solver(solver)
    gv.latest_scenario().set_config_instance_set_train(instance_set_train.directory)
    gv.latest_scenario().set_configuration_scenario(config_scenario.scenario_file_path)
    gv.latest_scenario().set_latest_scenario(Scenario.CONFIGURATION)

    if instance_set_test is not None:
        gv.latest_scenario().set_config_instance_set_test(instance_set_test.directory)
        # Schedule test set jobs
        if args.test_set_run_all_configurations:
            # TODO: Schedule test set runs for all configurations
            print("Running all configurations on test set is not implemented yet.")
            pass
        else:
            # We place the results in the index we just added
            run_index = list(set([performance_data.get_instance_num_runs(str(i))
                                  for i in instance_set_test.instance_paths]))
            test_set_job = solver.run_performance_dataframe(
                instance_set_test,
                run_index,
                performance_data,
                cutoff_time=config_scenario.cutoff_time,
                objective=config_scenario.sparkle_objective,
                train_set=instance_set_train,
                sbatch_options=sbatch_options,
                log_dir=config_scenario.validation,
                base_dir=sl.caller_log_dir,
                dependencies=dependency_job_list,
                job_name=f"Best Configuration: {solver.name} Validation on "
                         f"{instance_set_test.name}",
                run_on=run_on)
            dependency_job_list.append(test_set_job)
    else:
        # Set to default to overwrite possible old path
        gv.latest_scenario().set_config_instance_set_test()

    if run_on == Runner.SLURM:
        job_id_str = ",".join([run.run_id for run in dependency_job_list])
        print(f"Running configuration. Waiting for Slurm job(s) with id(s): "
              f"{job_id_str}")
    else:
        print("Running configuration finished!")

    # Write used settings to file
    gv.settings().write_used_settings()
    # Write used scenario to file
    gv.latest_scenario().write_scenario_ini()
    sys.exit(0)
# Script entry point: forward the CLI arguments (excluding the program name).
if __name__ == "__main__":
    main(sys.argv[1:])