Coverage for sparkle/solver/solver.py: 84%
215 statements
« prev ^ index » next — coverage.py v7.6.10, created at 2025-01-07 15:22 +0000
1"""File to handle a solver and its directories."""
2from __future__ import annotations
3import sys
4from typing import Any
5import shlex
6import ast
7import json
8from pathlib import Path
10from ConfigSpace import ConfigurationSpace
12import runrunner as rrr
13from runrunner.local import LocalRun
14from runrunner.slurm import Run, SlurmRun
15from runrunner.base import Status, Runner
17from sparkle.tools import pcsparser, RunSolver
18from sparkle.types import SparkleCallable, SolverStatus
19from sparkle.solver import verifiers
20from sparkle.instance import InstanceSet
21from sparkle.structures import PerformanceDataFrame
22from sparkle.types import resolve_objective, SparkleObjective, UseTime
class Solver(SparkleCallable):
    """Class to handle a solver and its directories."""
    # File name of the optional metadata file inside a solver directory,
    # read by __init__ to resolve `deterministic` and `verifier`
    meta_data = "solver_meta.txt"
    # File name of the wrapper script expected inside each solver directory
    # (invoked by build_cmd)
    wrapper = "sparkle_solver_wrapper.py"
    # CLI script used by run_performance_dataframe to run one solver job
    solver_cli = Path(__file__).parent / "solver_cli.py"
31 def __init__(self: Solver,
32 directory: Path,
33 raw_output_directory: Path = None,
34 runsolver_exec: Path = None,
35 deterministic: bool = None,
36 verifier: verifiers.SolutionVerifier = None) -> None:
37 """Initialize solver.
39 Args:
40 directory: Directory of the solver.
41 raw_output_directory: Directory where solver will write its raw output.
42 runsolver_exec: Path to the runsolver executable.
43 By default, runsolver in directory.
44 deterministic: Bool indicating determinism of the algorithm.
45 Defaults to False.
46 verifier: The solution verifier to use. If None, no verifier is used.
47 """
48 super().__init__(directory, runsolver_exec, raw_output_directory)
49 self.deterministic = deterministic
50 self.verifier = verifier
52 meta_data_file = self.directory / Solver.meta_data
53 if self.runsolver_exec is None:
54 self.runsolver_exec = self.directory / "runsolver"
55 if meta_data_file.exists():
56 meta_data = ast.literal_eval(meta_data_file.open().read())
57 # We only override the deterministic and verifier from file if not set
58 if self.deterministic is None:
59 if ("deterministic" in meta_data
60 and meta_data["deterministic"] is not None):
61 self.deterministic = meta_data["deterministic"]
62 if self.verifier is None and "verifier" in meta_data:
63 if isinstance(meta_data["verifier"], tuple): # File verifier
64 self.verifier = verifiers.mapping[meta_data["verifier"][0]](
65 Path(meta_data["verifier"][1])
66 )
67 elif meta_data["verifier"] in verifiers.mapping:
68 self.verifier = verifiers.mapping[meta_data["verifier"]]
69 if self.deterministic is None: # Default to False
70 self.deterministic = False
72 def __str__(self: Solver) -> str:
73 """Return the sting representation of the solver."""
74 return self.name
76 def _get_pcs_file(self: Solver, port_type: str = None) -> Path | bool:
77 """Get path of the parameter file.
79 Returns:
80 Path to the parameter file or False if the parameter file does not exist.
81 """
82 pcs_files = [p for p in self.directory.iterdir() if p.suffix == ".pcs"
83 and (port_type is None or port_type in p.name)]
85 if len(pcs_files) == 0:
86 return False
87 if len(pcs_files) != 1:
88 # Generated PCS files present, this is a quick fix to take the original
89 pcs_files = sorted(pcs_files, key=lambda p: len(p.name))
90 return pcs_files[0]
92 def get_pcs_file(self: Solver, port_type: str = None) -> Path:
93 """Get path of the parameter file.
95 Returns:
96 Path to the parameter file. None if it can not be resolved.
97 """
98 if not (file_path := self._get_pcs_file(port_type)):
99 return None
100 return file_path
102 def read_pcs_file(self: Solver) -> bool:
103 """Checks if the pcs file can be read."""
104 pcs_file = self._get_pcs_file()
105 try:
106 parser = pcsparser.PCSParser()
107 parser.load(str(pcs_file), convention="smac")
108 return True
109 except SyntaxError:
110 pass
111 return False
113 def get_pcs(self: Solver) -> dict[str, tuple[str, str, str]]:
114 """Get the parameter content of the PCS file."""
115 if not (pcs_file := self.get_pcs_file()):
116 return None
117 parser = pcsparser.PCSParser()
118 parser.load(str(pcs_file), convention="smac")
119 return [p for p in parser.pcs.params if p["type"] == "parameter"]
121 def port_pcs(self: Solver, port_type: pcsparser.PCSConvention) -> None:
122 """Port the parameter file to the given port type."""
123 pcs_file = self.get_pcs_file()
124 parser = pcsparser.PCSParser()
125 parser.load(str(pcs_file), convention="smac")
126 target_pcs_file = pcs_file.parent / f"{pcs_file.stem}_{port_type}.pcs"
127 if target_pcs_file.exists(): # Already exists, possibly user defined
128 return
129 parser.export(convention=port_type,
130 destination=target_pcs_file)
132 def get_configspace(self: Solver) -> ConfigurationSpace:
133 """Get the parameter content of the PCS file."""
134 if not (pcs_file := self.get_pcs_file()):
135 return None
136 parser = pcsparser.PCSParser()
137 parser.load(str(pcs_file), convention="smac")
138 return parser.get_configspace()
140 def get_forbidden(self: Solver, port_type: pcsparser.PCSConvention) -> Path:
141 """Get the path to the file containing forbidden parameter combinations."""
142 if port_type == "IRACE":
143 forbidden = [p for p in self.directory.iterdir()
144 if p.name.endswith("forbidden.txt")]
145 if len(forbidden) > 0:
146 return forbidden[0]
147 return None
149 def build_cmd(self: Solver,
150 instance: str | list[str],
151 objectives: list[SparkleObjective],
152 seed: int,
153 cutoff_time: int = None,
154 configuration: dict = None,
155 log_dir: Path = None) -> list[str]:
156 """Build the solver call on an instance with a configuration.
158 Args:
159 instance: Path to the instance.
160 seed: Seed of the solver.
161 cutoff_time: Cutoff time for the solver.
162 configuration: Configuration of the solver.
164 Returns:
165 List of commands and arguments to execute the solver.
166 """
167 if configuration is None:
168 configuration = {}
169 # Ensure configuration contains required entries for each wrapper
170 configuration["solver_dir"] = str(self.directory.absolute())
171 configuration["instance"] = instance
172 configuration["seed"] = seed
173 configuration["objectives"] = ",".join([str(obj) for obj in objectives])
174 configuration["cutoff_time"] =\
175 cutoff_time if cutoff_time is not None else sys.maxsize
176 if "configuration_id" in configuration:
177 del configuration["configuration_id"]
178 # Ensure stringification of dictionary will go correctly for key value pairs
179 configuration = {key: str(configuration[key]) for key in configuration}
180 solver_cmd = [str((self.directory / Solver.wrapper)),
181 f"'{json.dumps(configuration)}'"]
182 if log_dir is None:
183 log_dir = Path()
184 if cutoff_time is not None: # Use RunSolver
185 log_name_base = f"{Path(instance).name}_{self.name}"
186 return RunSolver.wrap_command(self.runsolver_exec,
187 solver_cmd,
188 cutoff_time,
189 log_dir,
190 log_name_base=log_name_base)
191 return solver_cmd
193 def run(self: Solver,
194 instances: str | list[str] | InstanceSet | list[InstanceSet],
195 objectives: list[SparkleObjective],
196 seed: int,
197 cutoff_time: int = None,
198 configuration: dict = None,
199 run_on: Runner = Runner.LOCAL,
200 sbatch_options: list[str] = None,
201 log_dir: Path = None,
202 ) -> SlurmRun | list[dict[str, Any]] | dict[str, Any]:
203 """Run the solver on an instance with a certain configuration.
205 Args:
206 instance: The instance(s) to run the solver on, list in case of multi-file.
207 In case of an instance set, will run on all instances in the set.
208 seed: Seed to run the solver with. Fill with abitrary int in case of
209 determnistic solver.
210 cutoff_time: The cutoff time for the solver, measured through RunSolver.
211 If None, will be executed without RunSolver.
212 configuration: The solver configuration to use. Can be empty.
213 log_dir: Path where to place output files. Defaults to
214 self.raw_output_directory.
216 Returns:
217 Solver output dict possibly with runsolver values.
218 """
219 if log_dir is None:
220 log_dir = self.raw_output_directory
221 cmds = []
222 instances = [instances] if not isinstance(instances, list) else instances
223 set_label = instances.name if isinstance(instances, InstanceSet) else "instances"
224 for instance in instances:
225 paths = instance.instace_paths if isinstance(instance,
226 InstanceSet) else [instance]
227 for instance_path in paths:
228 solver_cmd = self.build_cmd(instance_path,
229 objectives=objectives,
230 seed=seed,
231 cutoff_time=cutoff_time,
232 configuration=configuration,
233 log_dir=log_dir)
234 cmds.append(" ".join(solver_cmd))
236 commandname = f"Run Solver: {self.name} on {set_label}"
237 run = rrr.add_to_queue(runner=run_on,
238 cmd=cmds,
239 name=commandname,
240 base_dir=log_dir,
241 sbatch_options=sbatch_options)
243 if isinstance(run, LocalRun):
244 run.wait()
245 import time
246 time.sleep(5)
247 # Subprocess resulted in error
248 if run.status == Status.ERROR:
249 print(f"WARNING: Solver {self.name} execution seems to have failed!\n")
250 for i, job in enumerate(run.jobs):
251 print(f"[Job {i}] The used command was: {cmds[i]}\n"
252 "The error yielded was:\n"
253 f"\t-stdout: '{run.jobs[0]._process.stdout}'\n"
254 f"\t-stderr: '{run.jobs[0]._process.stderr}'\n")
255 return {"status": SolverStatus.ERROR, }
257 solver_outputs = []
258 for i, job in enumerate(run.jobs):
259 solver_cmd = cmds[i].split(" ")
260 solver_output = Solver.parse_solver_output(run.jobs[i].stdout,
261 solver_call=solver_cmd,
262 objectives=objectives,
263 verifier=self.verifier)
264 solver_outputs.append(solver_output)
265 return solver_outputs if len(solver_outputs) > 1 else solver_output
266 return run
    def run_performance_dataframe(
            self: Solver,
            instances: str | list[str] | InstanceSet,
            run_ids: int | list[int] | range[int, int]
            | list[list[int]] | list[range[int]],
            performance_dataframe: PerformanceDataFrame,
            cutoff_time: int = None,
            objective: SparkleObjective = None,
            train_set: InstanceSet = None,
            sbatch_options: list[str] = None,
            dependencies: list[SlurmRun] = None,
            log_dir: Path = None,
            base_dir: Path = None,
            job_name: str = None,
            run_on: Runner = Runner.SLURM) -> Run:
        """Run the solver and place the results in the performance dataframe.

        This in practice actually runs Solver.run, but has a little script
        before/after (solver_cli.py), to read and write to the performance
        dataframe.

        Args:
            instances: The instance(s) to run the solver on. In case of an
                instance set, or list, will create a job for all instances
                in the set/list.
            run_ids: The run indices to use in the performance dataframe.
                If int, will run only this id for all instances. If a list of
                integers or range, will run all run indexes for all instances.
                If a list of lists or list of ranges, will assume the runs are
                paired with the instances, e.g. will use sequence 1 for
                instance 1, ...
            performance_dataframe: The performance dataframe to use.
            cutoff_time: The cutoff time for the solver, measured through
                RunSolver.
            objective: The objective to use, only relevant for train set best
                config determining.
            train_set: The training set to use. If present, will determine the
                best configuration of the solver using these instances and run
                with it on all instances in the instance argument.
            sbatch_options: List of slurm batch options to use.
            dependencies: List of slurm runs to use as dependencies.
            log_dir: Path where to place output files. Defaults to
                self.raw_output_directory.
            base_dir: Path where to place output files.
            job_name: Name of the job.
                If None, will generate a name based on Solver and Instances.
            run_on: On which platform to run the jobs. Default: Slurm.

        Returns:
            SlurmRun or Local run of the job.
        """
        instances = [instances] if isinstance(instances, str) else instances
        set_name = "instances"
        if isinstance(instances, InstanceSet):
            set_name = instances.name
            instances = [str(i) for i in instances.instance_paths]
        # Resolve run_ids to which run indices to use for which instance
        if isinstance(run_ids, int):
            # One run id, replicated for every instance
            run_ids = [[run_ids]] * len(instances)
        elif isinstance(run_ids, range):
            # Same range of run ids for every instance
            run_ids = [list(run_ids)] * len(instances)
        elif isinstance(run_ids, list):
            if all(isinstance(i, int) for i in run_ids):
                # Same list of run ids for every instance
                run_ids = [run_ids] * len(instances)
            elif all(isinstance(i, range) for i in run_ids):
                # One range per instance; materialise each
                run_ids = [list(i) for i in run_ids]
            elif all(isinstance(i, list) for i in run_ids):
                # Already one list of run ids per instance
                pass
            else:
                raise TypeError(f"Invalid type combination for run_ids: {type(run_ids)}")
        objective_arg = f"--target-objective {objective.name}" if objective else ""
        train_arg =\
            ",".join([str(i) for i in train_set.instance_paths]) if train_set else ""
        # One solver_cli.py invocation per (instance, run index) pair
        cmds = [
            f"python3 {Solver.solver_cli} "
            f"--solver {self.directory} "
            f"--instance {instance} "
            f"--run-index {run_index} "
            f"--performance-dataframe {performance_dataframe.csv_filepath} "
            f"--cutoff-time {cutoff_time} "
            f"--log-dir {log_dir} "
            f"{objective_arg} "
            f"{'--best-configuration-instances' if train_set else ''} {train_arg}"
            for instance, run_indices in zip(instances, run_ids)
            for run_index in run_indices]
        job_name = f"Run: {self.name} on {set_name}" if job_name is None else job_name
        r = rrr.add_to_queue(
            runner=run_on,
            cmd=cmds,
            name=job_name,
            base_dir=base_dir,
            sbatch_options=sbatch_options,
            dependencies=dependencies
        )
        if run_on == Runner.LOCAL:
            # Local runs are synchronous: block until all jobs finish
            r.wait()
        return r
362 @staticmethod
363 def config_str_to_dict(config_str: str) -> dict[str, str]:
364 """Parse a configuration string to a dictionary."""
365 # First we filter the configuration of unwanted characters
366 config_str = config_str.strip().replace("-", "")
367 # Then we split the string by spaces, but conserve substrings
368 config_list = shlex.split(config_str)
369 # We return empty for empty input OR uneven input
370 if config_str == "" or config_str == r"{}" or len(config_list) & 1:
371 return {}
372 config_dict = {}
373 for index in range(0, len(config_list), 2):
374 # As the value will already be a string object, no quotes are allowed in it
375 value = config_list[index + 1].strip('"').strip("'")
376 config_dict[config_list[index]] = value
377 return config_dict
    @staticmethod
    def parse_solver_output(
            solver_output: str,
            solver_call: list[str | Path] = None,
            objectives: list[SparkleObjective] = None,
            verifier: verifiers.SolutionVerifier = None) -> dict[str, Any]:
        """Parse the output of the solver.

        Args:
            solver_output: The output of the solver run which needs to be parsed
            solver_call: The solver call used to run the solver
            objectives: The objectives to apply to the solver output
            verifier: The verifier to check the solver output

        Returns:
            Dictionary representing the parsed solver output
        """
        used_runsolver = False
        # A call with more than two arguments is assumed to be wrapped by
        # RunSolver (see build_cmd) — NOTE(review): heuristic, confirm
        if solver_call is not None and len(solver_call) > 2:
            used_runsolver = True
            parsed_output = RunSolver.get_solver_output(solver_call,
                                                        solver_output)
        else:
            # Raw wrapper output is a Python dict literal
            parsed_output = ast.literal_eval(solver_output)
        # cast status attribute from str to Enum
        parsed_output["status"] = SolverStatus(parsed_output["status"])
        if verifier is not None and used_runsolver:
            # Horrible hack to get the instance from the solver input
            solver_call_str: str = " ".join(solver_call)
            solver_input_str = solver_call_str.split(Solver.wrapper, maxsplit=1)[1]
            # Cut out the dict literal that was passed to the wrapper
            solver_input_str = solver_input_str[solver_input_str.index("{"):
                                                solver_input_str.index("}") + 1]
            solver_input = ast.literal_eval(solver_input_str)
            target_instance = Path(solver_input["instance"])
            # The verifier may override the solver-reported status
            parsed_output["status"] = verifier.verify(
                target_instance, parsed_output, solver_call)

        # Create objective map from objective stem to objective
        objectives = {o.stem: o for o in objectives} if objectives else {}
        removable_keys = ["cutoff_time"]  # Keys to remove

        # apply objectives to parsed output, runtime based objectives added here
        for key, value in parsed_output.items():
            if objectives and key in objectives:
                objective = objectives[key]
                removable_keys.append(key)  # We translate it into the full name
            else:
                # If not found in objectives, resolve to which objective
                # the output belongs
                objective = resolve_objective(key)
            if objective is None:  # Could not parse, skip
                continue
            if objective.use_time == UseTime.NO:
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(value)
            else:
                # Time-based objectives need RunSolver measurements
                if not used_runsolver:
                    continue
                if objective.use_time == UseTime.CPU_TIME:
                    parsed_output[key] = parsed_output["cpu_time"]
                else:
                    parsed_output[key] = parsed_output["wall_time"]
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(
                        parsed_output[key],
                        parsed_output["cutoff_time"],
                        parsed_output["status"])

        # Replace or remove keys based on the objective names
        for key in removable_keys:
            if key in parsed_output:
                if key in objectives:
                    # Map the result to the objective
                    parsed_output[objectives[key].name] = parsed_output[key]
                    if key != objectives[key].name:  # Only delete actual mappings
                        del parsed_output[key]
                else:
                    del parsed_output[key]
        return parsed_output