Coverage for sparkle/solver/ablation.py: 82%
112 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1#!/usr/bin/env python3
2# -*- coding: UTF-8 -*-
3"""Helper functions for ablation analysis."""
4from __future__ import annotations
5import re
6import shutil
7import decimal
8from pathlib import Path
10import runrunner as rrr
11from runrunner.base import Runner, Run
13from sparkle.CLI.help import global_variables as gv
14from sparkle.CLI.help import logging as sl
16from sparkle.configurator.implementations import SMAC2
17from sparkle.platform import CommandName
18from sparkle.solver import Solver
19from sparkle.instance import InstanceSet
22class AblationScenario:
23 """Class for ablation analysis."""
24 def __init__(self: AblationScenario,
25 solver: Solver,
26 train_set: InstanceSet,
27 test_set: InstanceSet,
28 output_dir: Path,
29 ablation_executable: Path = None,
30 ablation_validation_executable: Path = None,
31 override_dirs: bool = False) -> None:
32 """Initialize ablation scenario.
34 Args:
35 solver: Solver object
36 train_set: The training instance
37 test_set: The test instance
38 output_dir: The output directory
39 ablation_executable: (Only for execution) The ablation executable
40 ablation_validation_executable: (Only for execution) The validation exec
41 override_dirs: Whether to clean the scenario directory if it already exists
42 """
43 self.ablation_exec = ablation_executable
44 self.ablation_validation_exec = ablation_validation_executable
45 self.solver = solver
46 self.train_set = train_set
47 self.test_set = test_set
48 self.output_dir = output_dir
49 self.scenario_name = f"{self.solver.name}_{self.train_set.name}"
50 if self.test_set is not None:
51 self.scenario_name += f"_{self.test_set.name}"
52 self.scenario_dir = self.output_dir / self.scenario_name
53 if override_dirs and self.scenario_dir.exists():
54 print("Warning: found existing ablation scenario. This will be removed.")
55 shutil.rmtree(self.scenario_dir)
57 # Create required scenario directories
58 self.tmp_dir = self.scenario_dir / "tmp"
59 self.tmp_dir.mkdir(parents=True, exist_ok=True)
61 self.validation_dir = self.scenario_dir / "validation"
62 self.validation_dir_tmp = self.validation_dir / "tmp"
63 self.validation_dir_tmp.mkdir(parents=True, exist_ok=True)
64 self.table_file = self.validation_dir / "log" / "ablation-validation-run1234.txt"
66 def create_configuration_file(self: AblationScenario) -> None:
67 """Create a configuration file for ablation analysis.
69 Args:
70 solver: Solver object
71 instance_train_name: The training instance
72 instance_test_name: The test instance
74 Returns:
75 None
76 """
77 ablation_scenario_dir = self.scenario_dir
78 objective = gv.settings().get_general_sparkle_objectives()[0]
79 configurator = gv.settings().get_general_sparkle_configurator()
80 config_scenario = gv.latest_scenario().get_configuration_scenario(
81 configurator.scenario_class)
82 _, opt_config_str = configurator.get_optimal_configuration(
83 config_scenario)
85 # We need to check which params are missing and supplement with default values
86 pcs = self.solver.get_pcs()
87 for p in pcs:
88 if p["name"] not in opt_config_str:
89 opt_config_str += f" -{p['name']} {p['default']}"
91 # Ablation cannot deal with E scientific notation in floats
92 ctx = decimal.Context(prec=16)
93 for config in opt_config_str.split(" -"):
94 _, value = config.strip().split(" ")
95 if "e" in value.lower():
96 value = value.strip("'")
97 float_value = float(value.lower())
98 formatted = format(ctx.create_decimal(float_value), "f")
99 opt_config_str = opt_config_str.replace(value, formatted)
101 smac_run_obj = SMAC2.get_smac_run_obj(objective)
102 objective_str = "MEAN10" if smac_run_obj == "RUNTIME" else "MEAN"
103 run_cutoff_time = gv.settings().get_general_target_cutoff_time()
104 run_cutoff_length = gv.settings().get_smac2_target_cutoff_length()
105 concurrent_clis = gv.settings().get_slurm_max_parallel_runs_per_node()
106 ablation_racing = gv.settings().get_ablation_racing_flag()
107 configurator = gv.settings().get_general_sparkle_configurator()
108 pcs_file_path = f"{self.solver.get_pcs_file().absolute()}" # Get Solver PCS
110 # Create config file
111 config_file = Path(f"{ablation_scenario_dir}/ablation_config.txt")
112 config = (f'algo = "{SMAC2.configurator_target.absolute()} '
113 f'{self.solver.directory.absolute()} {objective}"\n'
114 f"execdir = {self.tmp_dir.absolute()}\n"
115 "experimentDir = ./\n"
116 f"deterministic = {1 if self.solver.deterministic else 0}\n"
117 f"run_obj = {smac_run_obj}\n"
118 f"overall_obj = {objective_str}\n"
119 f"cutoffTime = {run_cutoff_time}\n"
120 f"cutoff_length = {run_cutoff_length}\n"
121 f"cli-cores = {concurrent_clis}\n"
122 f"useRacing = {ablation_racing}\n"
123 "seed = 1234\n"
124 f"paramfile = {pcs_file_path}\n"
125 "instance_file = instances_train.txt\n"
126 "test_instance_file = instances_test.txt\n"
127 "sourceConfiguration=DEFAULT\n"
128 f'targetConfiguration="{opt_config_str}"')
129 config_file.open("w").write(config)
130 # Write config to validation directory
131 conf_valid = config.replace(f"execdir = {self.tmp_dir.absolute()}\n",
132 f"execdir = {self.validation_dir_tmp.absolute()}\n")
133 (self.validation_dir / config_file.name).open("w").write(conf_valid)
135 def create_instance_file(self: AblationScenario, test: bool = False) -> None:
136 """Create an instance file for ablation analysis."""
137 file_suffix = "_train.txt"
138 instance_set = self.train_set
139 if test:
140 file_suffix = "_test.txt"
141 instance_set = self.test_set if self.test_set is not None else self.train_set
142 # We give the Ablation script the paths of the instances
143 file_instance = self.scenario_dir / f"instances{file_suffix}"
144 with file_instance.open("w") as fh:
145 for instance in instance_set._instance_paths:
146 # We need to unpack the multi instance file paths in quotes
147 if isinstance(instance, list):
148 joined_instances = " ".join(
149 [str(file.absolute()) for file in instance])
150 fh.write(f"{joined_instances}\n")
151 else:
152 fh.write(f"{instance.absolute()}\n")
153 # Copy to validation directory
154 shutil.copyfile(file_instance, self.validation_dir / file_instance.name)
156 def check_for_ablation(self: AblationScenario) -> bool:
157 """Checks if ablation has terminated successfully."""
158 if not self.table_file.is_file():
159 return False
160 # First line in the table file should be "Ablation analysis validation complete."
161 table_line = self.table_file.open().readline().strip()
162 return table_line == "Ablation analysis validation complete."
164 def read_ablation_table(self: AblationScenario) -> list[list[str]]:
165 """Read from ablation table of a scenario."""
166 if not self.check_for_ablation():
167 # No ablation table exists for this solver-instance pair
168 return []
169 results = [["Round", "Flipped parameter", "Source value", "Target value",
170 "Validation result"]]
172 for line in self.table_file.open().readlines():
173 # Pre-process lines from the ablation file and add to the results dictionary.
174 # Sometimes ablation rounds switch multiple parameters at once.
175 # EXAMPLE: 2 EDR, EDRalpha 0, 0.1 1, 0.1013241633106732 486.31691
176 # To split the row correctly, we remove the space before the comma separated
177 # parameters and add it back.
178 # T.S. 30-01-2024: the results object is a nested list not dictionary?
179 values = re.sub(r"\s+", " ", line.strip())
180 values = re.sub(r", ", ",", values)
181 values = [val.replace(",", ", ") for val in values.split(" ")]
182 if len(values) == 5:
183 results.append(values)
184 return results
186 def submit_ablation(self: AblationScenario,
187 run_on: Runner = Runner.SLURM) -> list[Run]:
188 """Submit an ablation job.
190 Args:
191 run_on: Determines to which RunRunner queue the job is added
193 Returns:
194 A list of Run objects. Empty when running locally.
195 """
196 # 1. submit the ablation to the runrunner queue
197 clis = gv.settings().get_slurm_max_parallel_runs_per_node()
198 cmd = f"{self.ablation_exec.absolute()} --optionFile ablation_config.txt"
199 srun_options = ["-N1", "-n1", f"-c{clis}"]
200 sbatch_options = [f"--cpus-per-task={clis}"] +\
201 gv.settings().get_slurm_extra_options(as_args=True)
203 run_ablation = rrr.add_to_queue(
204 runner=run_on,
205 cmd=cmd,
206 name=CommandName.RUN_ABLATION,
207 base_dir=sl.caller_log_dir,
208 path=self.scenario_dir,
209 sbatch_options=sbatch_options,
210 srun_options=srun_options)
212 runs = []
213 if run_on == Runner.LOCAL:
214 run_ablation.wait()
215 runs.append(run_ablation)
217 # 2. Run ablation validation run if we have a test set to run on
218 if self.test_set is not None:
219 # Validation dir should have a copy of all needed files, except for the
220 # output of the ablation run, which is stored in ablation-run[seed].txt
221 cmd = f"{self.ablation_validation_exec.absolute()} "\
222 "--optionFile ablation_config.txt "\
223 "--ablationLogFile ../log/ablation-run1234.txt"
225 run_ablation_validation = rrr.add_to_queue(
226 runner=run_on,
227 cmd=cmd,
228 name=CommandName.RUN_ABLATION_VALIDATION,
229 path=self.validation_dir,
230 base_dir=sl.caller_log_dir,
231 dependencies=run_ablation,
232 sbatch_options=sbatch_options,
233 srun_options=srun_options)
235 if run_on == Runner.LOCAL:
236 run_ablation_validation.wait()
237 runs.append(run_ablation_validation)
239 return runs