Coverage for src/sparkle/solver/solver.py: 91% (234 statements)
coverage.py v7.13.1, created at 2026-01-21 15:31 +0000
1"""File to handle a solver and its directories."""
3from __future__ import annotations
4import sys
5from typing import Any
6import shlex
7import ast
8import json
9import random
10from pathlib import Path
12from ConfigSpace import ConfigurationSpace
14import runrunner as rrr
15from runrunner.local import LocalRun
16from runrunner.slurm import Run, SlurmRun
17from runrunner.base import Status, Runner
19from sparkle.tools.parameters import PCSConverter, PCSConvention
20from sparkle.tools import RunSolver
21from sparkle.types import SparkleCallable, SolverStatus
22from sparkle.solver import verifiers
23from sparkle.instance import InstanceSet
24from sparkle.structures import PerformanceDataFrame
25from sparkle.types import resolve_objective, SparkleObjective, UseTime


class Solver(SparkleCallable):
    """Class to handle a solver and its directories."""

    meta_data = "solver_meta.txt"
    _wrapper_file = "sparkle_solver_wrapper"
    solver_cli = Path(__file__).parent / "solver_cli.py"

    def __init__(
        self: Solver,
        directory: Path,
        runsolver_exec: Path = None,
        deterministic: bool = None,
        verifier: verifiers.SolutionVerifier = None,
    ) -> None:
        """Initialize solver.

        Args:
            directory: Directory of the solver.
            runsolver_exec: Path to the runsolver executable.
                By default, the runsolver in the solver directory.
            deterministic: Bool indicating determinism of the algorithm.
                Defaults to False.
            verifier: The solution verifier to use. If None, no verifier is used.
        """
        super().__init__(directory, runsolver_exec)
        self.deterministic = deterministic
        self.verifier = verifier
        self._pcs_file: Path = None
        self._interpreter: str = None
        self._wrapper_extension: str = None

        meta_data_file = self.directory / Solver.meta_data
        if meta_data_file.exists():
            # read_text avoids leaving the file handle open
            meta_data = ast.literal_eval(meta_data_file.read_text())
            # Only override deterministic and verifier from file if not set
            if self.deterministic is None:
                if (
                    "deterministic" in meta_data
                    and meta_data["deterministic"] is not None
                ):
                    self.deterministic = meta_data["deterministic"]
            if self.verifier is None and "verifier" in meta_data:
                if isinstance(meta_data["verifier"], tuple):  # File verifier
                    self.verifier = verifiers.mapping[meta_data["verifier"][0]](
                        Path(meta_data["verifier"][1])
                    )
                elif meta_data["verifier"] in verifiers.mapping:
                    self.verifier = verifiers.mapping[meta_data["verifier"]]
        if self.deterministic is None:  # Default to False
            self.deterministic = False
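
    # A minimal construction sketch (hypothetical paths and meta data values,
    # not part of this file): assuming "Solvers/MySolver" contains a wrapper
    # and optionally a solver_meta.txt such as
    #     {"deterministic": True, "verifier": None}
    # any argument left as None above is filled from that file:
    #
    #     solver = Solver(Path("Solvers/MySolver"))
    #     print(solver.deterministic)  # True, taken from the meta data file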

    def __str__(self: Solver) -> str:
        """Return the string representation of the solver."""
        return self.name

    def __repr__(self: Solver) -> str:
        """Return a detailed representation of the solver."""
        return (
            f"{self.name}:\n"
            f"\t- Directory: {self.directory}\n"
            f"\t- Deterministic: {self.deterministic}\n"
            f"\t- Verifier: {self.verifier}\n"
            f"\t- PCS File: {self.pcs_file}\n"
            f"\t- Wrapper: {self.wrapper}"
        )

    def __eq__(self: Solver, other: Any) -> bool:
        """Check whether two solvers are equal."""
        if isinstance(other, Solver):
            return other.directory == self.directory
        elif isinstance(other, str):
            return other == self.name or Path(other) == self.directory
        elif isinstance(other, Path):
            return other == self.directory
        return False

    def __hash__(self: Solver) -> int:
        """Return the parent class hash.

        Defining __eq__ sets __hash__ to None, so the parent implementation
        must be restored explicitly here.
        """
        return super().__hash__()
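
    # Equality sketch: a Solver compares equal to another Solver with the same
    # directory, to its own name given as a string, or to its directory given
    # as a string or Path (hypothetical paths for illustration):
    #
    #     solver = Solver(Path("Solvers/MySolver"))
    #     solver == "MySolver"                 # True, matches the name
    #     solver == Path("Solvers/MySolver")   # True, matches the directory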

    @property
    def pcs_file(self: Solver) -> Path:
        """Get path of the parameter file."""
        if self._pcs_file is None:
            for file in self.directory.iterdir():
                if file.name == Solver.meta_data:
                    continue  # Skip this file, never correct
                convention = PCSConverter.get_convention(file)
                if convention != PCSConvention.UNKNOWN:
                    self._pcs_file = file
                    break
        return self._pcs_file
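
    # Detection sketch: the first file in the solver directory whose format
    # PCSConverter recognises (anything but PCSConvention.UNKNOWN) is cached.
    # For a hypothetical directory holding "params.pcs" and "solver_meta.txt",
    # solver.pcs_file would resolve to the Path of "params.pcs".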

    def get_pcs_file_type(self: Solver, convention: PCSConvention) -> Path:
        """Get path of the parameter file of a specific convention."""
        for file in self.directory.iterdir():
            if file.name == Solver.meta_data:
                continue  # Skip this file, never correct
            if PCSConverter.get_convention(file) == convention:
                return file
        return None

    @property
    def wrapper_extension(self: Solver) -> str:
        """Get the extension of the wrapper file."""
        if self._wrapper_extension is None:
            # Determine the wrapper file; with multiple candidates,
            # take the alphabetically first one
            wrapper = sorted(
                p for p in self.directory.iterdir() if p.stem == Solver._wrapper_file
            )[0]
            self._wrapper_extension = wrapper.suffix
        return self._wrapper_extension

    @property
    def wrapper(self: Solver) -> str:
        """Get name of the wrapper file."""
        return f"{Solver._wrapper_file}{self.wrapper_extension}"

    @property
    def wrapper_file(self: Solver) -> Path:
        """Get path of the wrapper file."""
        return self.directory / self.wrapper

    def get_pcs_file(self: Solver, port_type: PCSConvention) -> Path:
        """Get path of the parameter file of a specific convention.

        Args:
            port_type: Port type of the parameter file. If None, will return
                the alphabetically first .pcs file found.

        Returns:
            Path to the parameter file. None if it cannot be resolved.
        """
        pcs_files = sorted(p for p in self.directory.iterdir() if p.suffix == ".pcs")
        if port_type is None:
            return pcs_files[0] if pcs_files else None
        for file in pcs_files:
            if port_type == PCSConverter.get_convention(file):
                return file
        return None
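
    # Lookup sketch (the PCSConvention member below is an assumption for
    # illustration; see sparkle.tools.parameters for the actual enum values):
    #
    #     specific = solver.get_pcs_file(PCSConvention.SMAC)
    #     fallback = solver.get_pcs_file(None)  # Alphabetically first .pcs file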

    def read_pcs_file(self: Solver) -> bool:
        """Check whether the PCS file can be read."""
        # TODO: Should be a .validate method instead
        return (
            self.pcs_file is not None
            and PCSConverter.get_convention(self.pcs_file) != PCSConvention.UNKNOWN
        )

    def get_configuration_space(self: Solver) -> ConfigurationSpace:
        """Get the ConfigurationSpace of the PCS file."""
        if not self.pcs_file:
            return None
        return PCSConverter.parse(self.pcs_file)
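
    # Sampling sketch: the returned object is a regular ConfigSpace
    # ConfigurationSpace, so it can be sampled directly (assuming a PCS file
    # is present):
    #
    #     space = solver.get_configuration_space()
    #     if space is not None:
    #         configuration = dict(space.sample_configuration())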

    def port_pcs(self: Solver, port_type: PCSConvention) -> None:
        """Port the parameter file to the given port type."""
        target_pcs_file = (
            self.pcs_file.parent / f"{self.pcs_file.stem}_{port_type.name}.pcs"
        )
        if target_pcs_file.exists():  # Already exists, possibly user defined
            return
        PCSConverter.export(self.get_configuration_space(), port_type, target_pcs_file)
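
    # Porting sketch: for a solver with "params.pcs", porting would write a
    # sibling file named "params_<CONVENTION>.pcs", leaving any pre-existing
    # (possibly user defined) target untouched:
    #
    #     solver.port_pcs(PCSConvention.SMAC)  # Enum member assumed, as above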

    def build_cmd(
        self: Solver,
        instance: str | list[str],
        objectives: list[SparkleObjective],
        seed: int,
        cutoff_time: int = None,
        configuration: dict = None,
        log_dir: Path = None,
    ) -> list[str]:
        """Build the solver call on an instance with a configuration.

        Args:
            instance: Path to the instance, or list of paths for
                multi-file instances.
            objectives: List of sparkle objectives.
            seed: Seed of the solver.
            cutoff_time: Cutoff time for the solver.
            configuration: Configuration of the solver.
            log_dir: Directory path for logs.

        Returns:
            List of commands and arguments to execute the solver.
        """
        if configuration is None:
            configuration = {}
        # Ensure configuration contains the entries required by each wrapper
        configuration["solver_dir"] = str(self.directory.absolute())
        configuration["instance"] = instance
        configuration["seed"] = seed
        configuration["objectives"] = ",".join([str(obj) for obj in objectives])
        configuration["cutoff_time"] = (
            cutoff_time if cutoff_time is not None else sys.maxsize
        )
        if "configuration_id" in configuration:
            del configuration["configuration_id"]
        # Stringify all values so the dictionary serialises cleanly
        configuration = {key: str(configuration[key]) for key in configuration}
        solver_cmd = [
            str(self.directory / self.wrapper),
            f"'{json.dumps(configuration)}'",
        ]
        if log_dir is None:
            log_dir = Path()
        if cutoff_time is not None:  # Use RunSolver
            log_path_str = instance[0] if isinstance(instance, list) else instance
            log_name_base = f"{Path(log_path_str).name}_{self.name}"
            return RunSolver.wrap_command(
                self.runsolver_exec,
                solver_cmd,
                cutoff_time,
                log_dir,
                log_name_base=log_name_base,
            )
        return solver_cmd
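
    # Command sketch: without cutoff_time the call is the wrapper plus a
    # quoted JSON blob; with cutoff_time it is wrapped by RunSolver.
    # Illustrative shape only (hypothetical paths and values):
    #
    #     solver.build_cmd("train/i1.cnf", objectives=objs, seed=42)
    #     # -> ["/abs/Solvers/MySolver/sparkle_solver_wrapper.py",
    #     #     "'{\"solver_dir\": \"/abs/Solvers/MySolver\",
    #     #        \"instance\": \"train/i1.cnf\", \"seed\": \"42\", ...}'"]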

    def run(
        self: Solver,
        instances: str | list[str] | InstanceSet | list[InstanceSet],
        objectives: list[SparkleObjective],
        seed: int,
        cutoff_time: int = None,
        configuration: dict = None,
        run_on: Runner = Runner.LOCAL,
        sbatch_options: list[str] = None,
        slurm_prepend: str | list[str] | Path = None,
        log_dir: Path = None,
    ) -> SlurmRun | list[dict[str, Any]] | dict[str, Any]:
        """Run the solver on an instance with a certain configuration.

        Args:
            instances: The instance(s) to run the solver on, list in case of
                multi-file instances. In case of an instance set, will run on
                all instances in the set.
            objectives: List of sparkle objectives.
            seed: Seed to run the solver with. Fill with an arbitrary int in
                case of a deterministic solver.
            cutoff_time: The cutoff time for the solver, measured through
                RunSolver. If None, will be executed without RunSolver.
            configuration: The solver configuration to use. Can be empty.
            run_on: Whether to run on Slurm or locally.
            sbatch_options: The sbatch options to use.
            slurm_prepend: The script to prepend to a Slurm script.
            log_dir: The log directory to use.

        Returns:
            Solver output dict, possibly with RunSolver values.
        """
        cmds = []
        set_label = instances.name if isinstance(instances, InstanceSet) else "instances"
        instances = [instances] if not isinstance(instances, list) else instances
        log_dir = Path() if log_dir is None else log_dir

        for instance in instances:
            paths = (
                instance.instance_paths
                if isinstance(instance, InstanceSet)
                else [instance]
            )
            for instance_path in paths:
                instance_path = (
                    [str(p) for p in instance_path]
                    if isinstance(instance_path, list)
                    else instance_path
                )
                solver_cmd = self.build_cmd(
                    instance_path,
                    objectives=objectives,
                    seed=seed,
                    cutoff_time=cutoff_time,
                    configuration=configuration,
                    log_dir=log_dir,
                )
                cmds.append(" ".join(solver_cmd))

        commandname = f"Run Solver {self.name} on {set_label}"
        run = rrr.add_to_queue(
            runner=run_on,
            cmd=cmds,
            name=commandname,
            base_dir=log_dir,
            sbatch_options=sbatch_options,
            prepend=slurm_prepend,
        )

        if isinstance(run, LocalRun):
            run.wait()
            if run.status == Status.ERROR:  # Subprocess resulted in error
                print(f"WARNING: Solver {self.name} execution seems to have failed!\n")
                for i, job in enumerate(run.jobs):
                    print(
                        f"[Job {i}] The used command was: {cmds[i]}\n"
                        "The error yielded was:\n"
                        f"\t-stdout: '{job.stdout}'\n"
                        f"\t-stderr: '{job.stderr}'\n"
                    )
                return {
                    "status": SolverStatus.ERROR,
                }

            solver_outputs = []
            for i, job in enumerate(run.jobs):
                solver_cmd = cmds[i].split(" ")
                solver_output = Solver.parse_solver_output(
                    job.stdout,
                    solver_call=solver_cmd,
                    objectives=objectives,
                    verifier=self.verifier,
                )
                solver_outputs.append(solver_output)
            return solver_outputs if len(solver_outputs) > 1 else solver_outputs[0]
        return run
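
    # Run sketch (hypothetical instance and objective names; a local run on a
    # single instance returns the parsed output dict directly):
    #
    #     output = solver.run("train/i1.cnf",
    #                         objectives=[resolve_objective("PAR10")],
    #                         seed=42,
    #                         cutoff_time=60,
    #                         run_on=Runner.LOCAL)
    #     print(output["status"])  # A SolverStatus value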

    def run_performance_dataframe(
        self: Solver,
        instances: str | list[str] | InstanceSet,
        performance_dataframe: PerformanceDataFrame,
        config_ids: str | list[str] = None,
        run_ids: list[int] | list[list[int]] = None,
        cutoff_time: int = None,
        objective: SparkleObjective = None,
        train_set: InstanceSet = None,
        sbatch_options: list[str] = None,
        slurm_prepend: str | list[str] | Path = None,
        dependencies: list[SlurmRun] = None,
        log_dir: Path = None,
        base_dir: Path = None,
        job_name: str = None,
        run_on: Runner = Runner.SLURM,
    ) -> Run:
        """Run the solver and place the results in the performance dataframe.

        In practice this runs Solver.run through a small CLI script that
        reads from and writes to the performance dataframe before and after
        the solver call.

        Args:
            instances: The instance(s) to run the solver on. In case of an
                instance set or list, will create a job for all instances in
                the set/list.
            performance_dataframe: The performance dataframe to use.
            config_ids: The config indices to use in the performance dataframe.
            run_ids: List of run ids to use. If a list of lists, a list of runs
                is given per instance. Otherwise, all runs are used for each
                instance.
            cutoff_time: The cutoff time for the solver, measured through
                RunSolver.
            objective: The objective to use, only relevant when determining the
                best configuration.
            train_set: The training set to use. If present, will determine the
                best configuration of the solver using these instances and run
                with it on all instances in the instances argument.
            sbatch_options: List of Slurm batch options to use.
            slurm_prepend: Slurm script to prepend to the sbatch.
            dependencies: List of Slurm runs to use as dependencies.
            log_dir: Path where to place output files. Defaults to CWD.
            base_dir: Path where to place output files.
            job_name: Name of the job. If None, will generate a name based on
                the solver and instances.
            run_on: On which platform to run the jobs. Default: Slurm.

        Returns:
            SlurmRun or Local run of the job.
        """
        instances = [instances] if isinstance(instances, str) else instances
        set_name = "instances"
        if isinstance(instances, InstanceSet):
            set_name = instances.name
            instances = [str(i) for i in instances.instance_paths]
        if not isinstance(config_ids, list):
            config_ids = [config_ids]
        configurations = [
            performance_dataframe.get_full_configuration(str(self.directory), config_id)
            if config_id
            else None
            for config_id in config_ids
        ]
        if run_ids is None:
            run_ids = performance_dataframe.run_ids
        if isinstance(run_ids[0], list):  # Runs per instance
            combinations = []
            for index, instance in enumerate(instances):
                for run_id in run_ids[index]:
                    combinations.extend(
                        [
                            (instance, config_id, config, run_id)
                            for config_id, config in zip(config_ids, configurations)
                        ]
                    )
        else:  # Same runs for all instances
            combinations = [
                (instance, config_data[0], config_data[1], run_id)
                for instance, config_data, run_id in itertools.product(
                    instances,
                    zip(config_ids, configurations),
                    run_ids,
                )
            ]
        objective_arg = f"--target-objective {objective.name}" if objective else ""
        train_arg = (
            "--best-configuration-instances "
            + " ".join([str(i) for i in train_set.instance_paths])
            if train_set
            else ""
        )
        configuration_args = [
            f"--configuration '{json.dumps(config)}'"
            if config
            else f"--configuration-id {config_id}"
            if config_id
            else ""
            for _, config_id, config, _ in combinations
        ]

        # We run all instance/configuration/run combinations.
        # Each value is resolved from the PerformanceDataFrame up front,
        # to avoid high read loads during execution
        cmds = [
            f"python3 {Solver.solver_cli} "
            f"--solver {self.directory} "
            f"--instance {instance} "
            f"{config_arg} "
            f"--run-index {run_id} "
            f"--objectives {' '.join([obj.name for obj in performance_dataframe.objectives])} "
            f"--performance-dataframe {performance_dataframe.csv_filepath} "
            f"--cutoff-time {cutoff_time} "
            f"--log-dir {log_dir} "
            f"--seed {random.randint(0, 2**32 - 1)} "
            f"{objective_arg} "
            f"{train_arg}"
            for (instance, _, _, run_id), config_arg in zip(
                combinations, configuration_args
            )
        ]
        job_name = f"Run {self.name} on {set_name}" if job_name is None else job_name
        r = rrr.add_to_queue(
            runner=run_on,
            cmd=cmds,
            name=job_name,
            base_dir=base_dir,
            sbatch_options=sbatch_options,
            prepend=slurm_prepend,
            dependencies=dependencies,
        )
        if run_on == Runner.LOCAL:
            r.wait()
        return r
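
    # Scheduling sketch (hypothetical dataframe and config id): each generated
    # command invokes solver_cli.py, which writes its result into the
    # dataframe row for the (instance, config_id, run_id) combination:
    #
    #     run = solver.run_performance_dataframe(
    #         instances=train_set,
    #         performance_dataframe=pdf,
    #         config_ids="default",
    #         cutoff_time=60,
    #         run_on=Runner.LOCAL)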

    @staticmethod
    def config_str_to_dict(config_str: str) -> dict[str, str]:
        """Parse a configuration string to a dictionary."""
        # Split the string by spaces, but conserve quoted substrings
        config_list = shlex.split(config_str.strip())
        # We return empty for empty input OR uneven input
        if not config_list or len(config_list) & 1:
            return {}
        config_dict = {}
        for index in range(0, len(config_list), 2):
            # Strip the flag dashes from the key only, so values such as
            # negative numbers are preserved
            key = config_list[index].lstrip("-")
            # As the value is already a string object, no quotes are allowed in it
            value = config_list[index + 1].strip('"').strip("'")
            config_dict[key] = value
        return config_dict
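
    # Parsing sketch (hypothetical parameter names):
    #
    #     Solver.config_str_to_dict("-init_solution '1' -p_swt '0.3'")
    #     # -> {"init_solution": "1", "p_swt": "0.3"}
    #
    # Empty or uneven input yields {}.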

    @staticmethod
    def parse_solver_output(
        solver_output: str,
        solver_call: list[str | Path] = None,
        objectives: list[SparkleObjective] = None,
        verifier: verifiers.SolutionVerifier = None,
    ) -> dict[str, Any]:
        """Parse the output of the solver.

        Args:
            solver_output: The output of the solver run which needs to be parsed.
            solver_call: The solver call used to run the solver.
            objectives: The objectives to apply to the solver output.
            verifier: The verifier to check the solver output.

        Returns:
            Dictionary representing the parsed solver output.
        """
        used_runsolver = False
        if (
            solver_call is not None
            and len(solver_call) > 2
            and (
                solver_call[0].endswith("runsolver")
                or solver_call[1].endswith("py_runsolver.py")
            )
        ):
            used_runsolver = True  # PyRunsolver or RunSolver was used
            parsed_output = RunSolver.get_solver_output(solver_call, solver_output)
        else:
            parsed_output = ast.literal_eval(solver_output)
        # Cast status attribute from str to Enum
        parsed_output["status"] = SolverStatus(parsed_output["status"])
        if verifier is not None and used_runsolver:
            # Horrible hack to get the instance from the solver input
            solver_call_str: str = " ".join(solver_call)
            solver_input_str = solver_call_str.split(Solver._wrapper_file, maxsplit=1)[1]
            solver_input_str = solver_input_str.split(" ", maxsplit=1)[1]
            solver_input_str = solver_input_str[
                solver_input_str.index("{") : solver_input_str.index("}") + 1
            ]
            solver_input = ast.literal_eval(solver_input_str)
            target_instance = Path(solver_input["instance"])
            parsed_output["status"] = verifier.verify(
                target_instance, parsed_output, solver_call
            )

        # Create objective map
        objectives = {o.stem: o for o in objectives} if objectives else {}
        removable_keys = ["cutoff_time"]  # Keys to remove

        # Apply objectives to parsed output; runtime based objectives are added here
        for key, value in parsed_output.items():
            if objectives and key in objectives:
                objective = objectives[key]
                removable_keys.append(key)  # We translate it into the full name
            else:
                # Not found in objectives; resolve to which objective the
                # output belongs
                objective = resolve_objective(key)
            if objective is None:  # Could not parse, skip
                continue
            if objective.use_time == UseTime.NO:
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(value)
            else:
                if not used_runsolver:
                    continue
                if objective.use_time == UseTime.CPU_TIME:
                    parsed_output[key] = parsed_output["cpu_time"]
                else:
                    parsed_output[key] = parsed_output["wall_time"]
                if objective.post_process is not None:
                    parsed_output[key] = objective.post_process(
                        parsed_output[key],
                        parsed_output["cutoff_time"],
                        parsed_output["status"],
                    )

        # Replace or remove keys based on the objective names
        for key in removable_keys:
            if key in parsed_output:
                if key in objectives:
                    # Map the result to the full objective name
                    parsed_output[objectives[key].name] = parsed_output[key]
                    if key != objectives[key].name:  # Only delete actual mappings
                        del parsed_output[key]
                else:
                    del parsed_output[key]
        return parsed_output
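
    # Parsing sketch (hypothetical wrapper output and objective name): given
    #     solver_output = "{'status': 'SUCCESS', 'quality': '42'}"
    # and a SparkleObjective whose stem is "quality", the result maps the full
    # objective name to the (possibly post-processed) value and "status" to a
    # SolverStatus enum member. When RunSolver was used, time-based objectives
    # are filled from the measured cpu_time/wall_time instead.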