Coverage for sparkle/tools/solver_wrapper_parsing.py: 23%
30 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1"""This module provides tools for the argument parsing for solver wrappers."""
2from pathlib import Path
3import ast
4from typing import Any
6from sparkle.types import resolve_objective
def parse_commandline_dict(args: list[str]) -> dict:
    """Parse a dictionary literal passed via (possibly split) command line arguments.

    The arguments are joined with single spaces and the dictionary literal is
    taken to start at the first ``{``. Anything before it, or after its closing
    brace, is discarded (e.g. extra tokens present when running from a Slurm
    script). The literal is evaluated with ``ast.literal_eval``, so only Python
    literals are accepted and no code is executed.

    Args:
        args: command line tokens that together contain a dictionary literal.

    Returns:
        The evaluated dictionary.

    Raises:
        ValueError: if ``args`` contains no ``{ ... }`` span.
        SyntaxError, ValueError, TypeError: propagated from ``ast.literal_eval``
            when no candidate span evaluates to a valid literal.
    """
    dict_str = " ".join(args)
    start = dict_str.find("{")
    end = dict_str.find("}", start + 1) if start != -1 else -1
    if end == -1:
        raise ValueError(f"No dictionary literal found in arguments: {dict_str!r}")
    # Slurm script fix: the matching closing brace is not necessarily the first
    # '}' (nested dicts, values containing '}') nor the last (trailing text).
    # Try each candidate from left to right and return the first that evaluates;
    # the first candidate is exactly the span the simple approach would use.
    last_error = None
    while end != -1:
        try:
            return ast.literal_eval(dict_str[start:end + 1])
        except (SyntaxError, ValueError, TypeError) as err:
            last_error = err
        end = dict_str.find("}", end + 1)
    # At least one candidate was tried, so last_error is set here.
    raise last_error
def _parse_configuration_line(config_str: str) -> dict[str, str]:
    """Parse a ``-name value -name value ...`` configuration line into a dict.

    Each argument chunk has its surrounding whitespace removed, all single and
    double quotes deleted, and leading/trailing dashes stripped. It is then
    split on the first space into a name and a value.

    NOTE: arguments are separated on the substring " -", so an unquoted value
    that itself starts with '-' (e.g. a negative number) is split away from its
    name. Quoted values (``-x '-1'``) are not affected.

    Args:
        config_str: a single line of a configuration file.

    Returns:
        A dictionary mapping parameter names to their (string) values.

    Raises:
        ValueError: if an argument chunk has no value after its name.
    """
    params: dict[str, str] = {}
    for chunk in config_str.split(" -"):
        if chunk.strip() == "":
            continue
        # Quotes are removed entirely, so no further quote stripping is needed
        arg = chunk.strip().replace("'", "").replace('"', "").strip("-")
        name_value = arg.split(" ", maxsplit=1)
        if len(name_value) != 2:
            raise ValueError(f"Configuration argument without a value: {arg!r}")
        params[name_value[0]] = name_value[1]
    return params


def parse_solver_wrapper_args(args: list[str]) -> dict[str, Any]:
    """Parse the arguments passed to the solver wrapper.

    The arguments form a dictionary literal (parsed by ``parse_commandline_dict``).
    The keys ``solver_dir`` and ``instance`` become ``Path`` objects, ``seed``
    an ``int``, ``cutoff_time`` a ``float``, and ``objectives`` (a
    comma-separated string) a list of resolved objectives. When a
    ``config_path`` key is present, the line at index ``seed`` of that file is
    parsed as ``-name value`` pairs which are added to the result, after which
    ``config_path`` is removed.

    Args:
        args: a list of arguments passed via the command line. It is ensured by Sparkle
            that this list contains certain keys such as `solver_dir`.

    Returns:
        A dictionary mapping argument names to their currently held values.

    Raises:
        KeyError: if one of the required keys is missing.
        IndexError: if the configuration file has no line at index ``seed``.
        ValueError: if a value cannot be converted, or a configuration
            argument has no value.
    """
    args_dict = parse_commandline_dict(args)

    # Some data needs specific formatting
    args_dict["solver_dir"] = Path(args_dict["solver_dir"])
    args_dict["instance"] = Path(args_dict["instance"])
    args_dict["seed"] = int(args_dict["seed"])
    args_dict["objectives"] = [resolve_objective(name)
                               for name in args_dict["objectives"].split(",")]
    args_dict["cutoff_time"] = float(args_dict["cutoff_time"])

    if "config_path" in args_dict:
        # The arguments were not directly given and must be parsed from a file.
        # The seed is used as the index of the configuration line to read.
        # Use a context manager so the file handle is always closed.
        with Path(args_dict["config_path"]).open("r") as config_file:
            config_str = config_file.readlines()[args_dict["seed"]]
        args_dict.update(_parse_configuration_line(config_str))
        del args_dict["config_path"]

    return args_dict
def get_solver_call_params(args_dict: dict) -> list[str]:
    """Build the extra command line parameters for the solver call.

    Every entry of ``args_dict`` is turned into a ``-name value`` pair of
    strings, in the dictionary's iteration order. Entries whose value is
    ``None`` are dropped, as are the keys the wrapper already consumes itself
    (``solver_dir``, ``instance``, ``cutoff_time``, ``seed``, ``objectives``).

    Args:
        args_dict: Dictionary mapping argument names to their currently held values

    Returns:
        A flat list of parameters for the solver call
    """
    # Keys handled by the wrapper directly; they must not be forwarded
    skipped = {"solver_dir", "instance", "cutoff_time", "seed", "objectives"}
    call_params: list[str] = []
    for name, value in args_dict.items():
        if name in skipped or value is None:
            continue
        call_params += ["-" + str(name), str(value)]
    return call_params