Coverage for sparkle/tools/solver_wrapper_parsing.py: 23%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 14:48 +0000

1"""This module provides tools for the argument parsing for solver wrappers.""" 

2from pathlib import Path 

3import ast 

4from typing import Any 

5 

6from sparkle.types import resolve_objective 

7 

8 

def parse_commandline_dict(args: list[str]) -> dict:
    """Parses a commandline dictionary to the object.

    The arguments are joined with single spaces and the dictionary literal is cut
    out of the resulting string. Any text before the first ``{`` or after the
    matching closing ``}`` is ignored (e.g. tokens appended by a Slurm script).

    Args:
        args: Command line tokens that together contain a Python dict literal.

    Returns:
        The object produced by ``ast.literal_eval`` (normally a dict).

    Raises:
        ValueError: If the joined arguments contain no ``{``, or no ``}`` after it.
        SyntaxError, ValueError, TypeError: Re-raised from ``ast.literal_eval`` when
            no candidate substring forms a valid literal.
    """
    dict_str = " ".join(args)
    start = dict_str.index("{")
    # Slurm script fix: only take the text up to a closing brace, dropping any
    # trailing tokens. Search from the opening brace, not from index 0.
    end = dict_str.index("}", start)
    while True:
        try:
            return ast.literal_eval(dict_str[start:end + 1])
        except (SyntaxError, ValueError, TypeError):
            # The first "}" may close a nested dict or sit inside a string value.
            # Retry with the next closing brace; re-raise once none are left.
            next_end = dict_str.find("}", end + 1)
            if next_end == -1:
                raise
            end = next_end

14 

15 

def parse_solver_wrapper_args(args: list[str]) -> dict[str, Any]:
    """Parse the arguments passed to the solver wrapper.

    Args:
        args: a list of arguments passed via the command line. It is ensured by Sparkle
            that this list contains certain keys such as `solver_dir`.

    Returns:
        A dictionary mapping argument names to their currently held values.
        `solver_dir` and `instance` are converted to `Path`, `seed` to `int`,
        `cutoff_time` to `float`, and `objectives` (a comma separated string) to a
        list of resolved objectives. When a `config_path` key is present, the
        configuration on line `seed` (0-based) of that file is merged in with
        string values, and the `config_path` key is removed.
    """
    args_dict = parse_commandline_dict(args)

    # Some data needs specific formatting
    args_dict["solver_dir"] = Path(args_dict["solver_dir"])
    args_dict["instance"] = Path(args_dict["instance"])
    args_dict["seed"] = int(args_dict["seed"])
    args_dict["objectives"] = [resolve_objective(name)
                               for name in args_dict["objectives"].split(",")]
    args_dict["cutoff_time"] = float(args_dict["cutoff_time"])

    if "config_path" in args_dict:
        # The arguments were not directly given and must be parsed from a file.
        # The seed doubles as the index of the configuration line to use.
        # A context manager guarantees the file handle is closed.
        with Path(args_dict["config_path"]).open("r") as config_file:
            config_str = config_file.readlines()[args_dict["seed"]]
        args_dict.update(_parse_config_line(config_str))
        del args_dict["config_path"]

    return args_dict


def _parse_config_line(config_str: str) -> dict[str, str]:
    """Parse a `-name value -name value ...` line into a name -> value dict.

    All single and double quotes are removed and values are kept as strings.
    Everything after the first space of a parameter belongs to its value.
    """
    # Extract the args without any quotes
    cleaned = [piece.strip().replace("'", "").replace('"', "")
               for piece in config_str.split(" -") if piece.strip() != ""]
    # Splitting on " -" also cuts unquoted negative values ("-alpha -0.5" becomes
    # "-alpha" and "0.5"), which used to crash on unpacking. Glue purely numeric
    # fragments back onto the preceding piece, restoring their minus sign.
    merged: list[str] = []
    for piece in cleaned:
        if merged and _is_number(piece):
            merged[-1] = f"{merged[-1]} -{piece}"
        else:
            merged.append(piece)

    params: dict[str, str] = {}
    for piece in merged:
        varname, value = piece.strip("-").split(" ", maxsplit=1)
        params[varname] = value
    return params


def _is_number(text: str) -> bool:
    """Return True if `text` can be parsed by `float`."""
    try:
        float(text)
    except ValueError:
        return False
    return True

49 

50 

def get_solver_call_params(args_dict: dict) -> list[str]:
    """Build the extra `-name value` tokens to append to the solver call.

    Args:
        args_dict: Dictionary mapping argument names to their currently held values

    Returns:
        A flat list alternating `-<name>` and `<value>` strings, in the dictionary's
        iteration order. Entries whose value is None are skipped, as are the keys
        `solver_dir`, `instance`, `cutoff_time`, `seed` and `objectives`.
    """
    # These arguments are not relevant for the solver or have already been processed
    handled_elsewhere = {"solver_dir", "instance", "cutoff_time", "seed", "objectives"}
    return [token
            for name, value in args_dict.items()
            if name not in handled_elsewhere and value is not None
            for token in ("-" + str(name), str(value))]