Coverage for sparkle/types/objective.py: 98%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-29 10:17 +0000

1"""Class for Sparkle Objective and Performance.""" 

2 

3from __future__ import annotations 

4from enum import Enum 

5import typing 

6import numpy as np 

7from sparkle.types.status import SolverStatus 

8 

9 

class UseTime(str, Enum):
    """Enum describing what type of time to use.

    Because of the ``str`` mixin, members compare equal to their string
    values (``UseTime.NO == "NO"``). Looking up an unrecognised value, e.g.
    ``UseTime("bogus")``, does not raise ``ValueError``; it falls back to
    ``UseTime.NO`` via ``_missing_``.
    """

    WALL_TIME = "WALL_TIME"
    CPU_TIME = "CPU_TIME"
    NO = "NO"

    @classmethod
    def _missing_(cls: type[UseTime], value: object) -> UseTime:
        """Return ``NO`` for any value that matches no member.

        Args:
            value: The value that failed the normal member lookup (ignored).

        Returns:
            The ``NO`` member, so unknown inputs are treated as "no time".
        """
        return cls.NO

21 

22 

class SparkleObjective:
    """Objective for Sparkle specified by user.

    Bundles an objective name with the callables used to aggregate values
    over runs, instances and solvers, the optimisation direction, an optional
    post-processing hook, and which kind of time (if any) the objective uses.
    """

    name: str
    run_aggregator: typing.Callable
    instance_aggregator: typing.Callable
    solver_aggregator: typing.Callable
    minimise: bool
    post_process: typing.Callable | None
    use_time: UseTime
    metric: bool

    def __init__(
        self: SparkleObjective,
        name: str,
        run_aggregator: typing.Callable = np.mean,
        instance_aggregator: typing.Callable = np.mean,
        solver_aggregator: typing.Callable | None = None,
        minimise: bool = True,
        post_process: typing.Callable | None = None,
        use_time: UseTime = UseTime.NO,
        metric: bool = False,
    ) -> None:
        """Create a Sparkle objective.

        Args:
            name: Objective name; the part before the first ":" is its stem.
            run_aggregator: Callable used to aggregate values over runs.
            instance_aggregator: Callable used to aggregate values over
                instances.
            solver_aggregator: Callable used to aggregate values over
                solvers. When None, defaults to ``np.min`` if ``minimise``
                else ``np.max``.
            minimise: Whether lower objective values are better.
            post_process: Optional callable stored for post-processing
                values; None when unused. (Expected call signature is
                defined by callers — not visible here.)
            use_time: Which time measurement the objective is based on;
                ``UseTime.NO`` means it is not time based.
            metric: Flag stored as ``self.metric``; its meaning is defined
                by callers — TODO confirm.
        """
        self.name = name
        self.run_aggregator: typing.Callable = run_aggregator
        self.instance_aggregator: typing.Callable = instance_aggregator
        if solver_aggregator is None:
            # Default to selecting the best solver value for the direction.
            solver_aggregator = np.min if minimise else np.max
        self.solver_aggregator: typing.Callable = solver_aggregator
        self.minimise: bool = minimise
        self.post_process: typing.Callable | None = post_process
        self.use_time: UseTime = use_time
        self.metric: bool = metric

    def __str__(self: SparkleObjective) -> str:
        """Return a stringified version."""
        return self.name

    @property
    def stem(self: SparkleObjective) -> str:
        """Return the stem of the objective name (text before the first ":")."""
        return self.name.split(":")[0]

    @property
    def time(self: SparkleObjective) -> bool:
        """Return whether the objective is time based."""
        return self.use_time != UseTime.NO

71 

72 

class PAR(SparkleObjective):
    """Penalised Averaged Runtime objective for Sparkle.

    A run is penalised when its solver status is listed in
    ``negative_status`` or when its measured value exceeds the cutoff; a
    penalised run scores ``cutoff * k``, any other run keeps its value.
    """

    # Solver outcomes that always receive the penalty, whatever the runtime.
    negative_status = {
        SolverStatus.CRASHED,
        SolverStatus.KILLED,
        SolverStatus.ERROR,
        SolverStatus.TIMEOUT,
        SolverStatus.WRONG,
        SolverStatus.UNKNOWN,
    }

    def __init__(
        self: PAR, k: int = 10, minimise: bool = True, metric: bool = False
    ) -> None:
        """Initialize PAR.

        Args:
            k: Multiplier applied to the cutoff for penalised runs.
            minimise: Whether lower objective values are better.
            metric: Forwarded unchanged to ``SparkleObjective``.

        Raises:
            ValueError: If ``k`` is not greater than 0.
        """
        self.k = k
        if k <= 0:
            raise ValueError("k must be greater than 0.")

        def penalise(value: float, cutoff: float, status: SolverStatus) -> float:
            """Return ``cutoff * k`` for a penalised run, else ``value``."""
            # ``self.k`` is read at call time, not captured at construction.
            penalised = status in PAR.negative_status or value > cutoff
            return cutoff * self.k if penalised else value

        objective_name = f"PAR{k}"
        super().__init__(
            objective_name,
            use_time=UseTime.CPU_TIME,
            minimise=minimise,
            metric=metric,
            post_process=penalise,
        )