Coverage for sparkle/types/objective.py: 96%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 15:22 +0000

1"""Class for Sparkle Objective and Performance.""" 

2from __future__ import annotations 

3from enum import Enum 

4import typing 

5import numpy as np 

6from sparkle.types.status import SolverStatus 

7 

8 

class UseTime(str, Enum):
    """Enumerate which kind of time an objective measures, if any.

    Members mix in ``str``, so each member compares equal to its plain string
    value (e.g. ``UseTime.CPU_TIME == "CPU_TIME"``). Constructing the enum
    from a value that matches no member yields ``NO`` instead of raising.
    """

    WALL_TIME = "WALL_TIME"
    CPU_TIME = "CPU_TIME"
    NO = "NO"

    @classmethod
    def _missing_(cls: UseTime, value: object) -> UseTime:
        """Resolve any unrecognised lookup value to ``NO``.

        Enum calls this hook only after the normal value lookup has failed,
        so real member values never reach it.
        """
        fallback = cls.NO
        return fallback

19 

20 

class SparkleObjective:
    """Objective for Sparkle specified by user.

    Bundles an objective name with the callables used to aggregate and
    post-process its values, the optimisation direction, and which kind of
    time (if any) the objective is based on.
    """

    name: str
    run_aggregator: typing.Callable
    instance_aggregator: typing.Callable
    solver_aggregator: typing.Callable
    minimise: bool
    post_process: typing.Optional[typing.Callable]
    use_time: UseTime
    metric: bool

    def __init__(self: SparkleObjective,
                 name: str,
                 run_aggregator: typing.Callable = np.mean,
                 instance_aggregator: typing.Callable = np.mean,
                 solver_aggregator: typing.Optional[typing.Callable] = None,
                 minimise: bool = True,
                 post_process: typing.Optional[typing.Callable] = None,
                 use_time: UseTime = UseTime.NO,
                 metric: bool = False) -> None:
        """Create a Sparkle objective.

        Args:
            name: Objective name. ``stem`` is the part before the first ``:``.
            run_aggregator: Callable stored as-is (presumably combines values
                across runs). Defaults to ``np.mean``.
            instance_aggregator: Callable stored as-is (presumably combines
                values across instances). Defaults to ``np.mean``.
            solver_aggregator: Callable stored as-is. When ``None``, it
                defaults to ``np.min`` if ``minimise`` else ``np.max``.
            minimise: Whether lower values are better.
            post_process: Optional callable stored as-is; may be ``None``.
            use_time: Which time the objective is based on. Any value other
                than ``UseTime.NO`` makes ``time`` return ``True``.
            metric: Flag stored as-is.
        """
        self.name = name
        self.run_aggregator: typing.Callable = run_aggregator
        self.instance_aggregator: typing.Callable = instance_aggregator
        if solver_aggregator is None:
            # The best solver is the lowest value when minimising and the
            # highest value otherwise.
            solver_aggregator = np.min if minimise else np.max
        self.solver_aggregator: typing.Callable = solver_aggregator
        self.minimise: bool = minimise
        self.post_process: typing.Optional[typing.Callable] = post_process
        self.use_time: UseTime = use_time
        self.metric = metric

    def __str__(self: SparkleObjective) -> str:
        """Return a stringified version (the objective name)."""
        return self.name

    def __repr__(self: SparkleObjective) -> str:
        """Return a debugging representation naming the class and objective."""
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"minimise={self.minimise!r})")

    @property
    def stem(self: SparkleObjective) -> str:
        """Return the stem of the objective name (text before the first ':')."""
        return self.name.split(":")[0]

    @property
    def time(self: SparkleObjective) -> bool:
        """Return whether the objective is time based (use_time is not NO)."""
        return self.use_time != UseTime.NO

67 

68 

class PAR(SparkleObjective):
    """Penalised Averaged Runtime Objective for Sparkle.

    A run whose status is in ``negative_status``, or whose value exceeds the
    cutoff, is scored as ``cutoff * k`` instead of its measured value.
    """

    # Solver statuses that count as failed runs and are therefore penalised.
    negative_status = {SolverStatus.CRASHED,
                       SolverStatus.KILLED,
                       SolverStatus.ERROR,
                       SolverStatus.TIMEOUT,
                       SolverStatus.WRONG}

    def __init__(self: PAR, k: int = 10,
                 minimise: bool = True,
                 metric: bool = False) -> None:
        """Initialize PAR.

        Args:
            k: Penalty multiplier applied to the cutoff. It also forms the
                objective name ``f"PAR{k}"``. Must be greater than 0.
            minimise: Forwarded to ``SparkleObjective``.
            metric: Forwarded to ``SparkleObjective``.

        Raises:
            ValueError: If ``k`` is not greater than 0.
        """
        if k <= 0:
            raise ValueError("k must be greater than 0.")
        self.k = k
        # Pass a bound method rather than a function defined inside
        # __init__: local functions cannot be pickled, bound methods can.
        super().__init__(f"PAR{k}",
                         minimise=minimise,
                         use_time=UseTime.CPU_TIME,
                         post_process=self.penalise,
                         metric=metric)

    def penalise(self: PAR, value: float, cutoff: float,
                 status: SolverStatus) -> float:
        """Return penalised value.

        Args:
            value: Measured value of the run.
            cutoff: Cutoff the run was subject to.
            status: Final solver status of the run.

        Returns:
            ``cutoff * self.k`` if ``status`` is a negative status or
            ``value`` exceeds ``cutoff``; otherwise ``value`` unchanged.
        """
        # Looked up through self so subclasses can override negative_status.
        if status in self.negative_status or value > cutoff:
            return cutoff * self.k
        return value