Coverage for sparkle/types/objective.py: 91%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-27 09:10 +0000

1"""Class for Sparkle Objective and Performance.""" 

2from __future__ import annotations 

3from enum import Enum 

4import typing 

5import numpy as np 

6 

7 

class UseTime(str, Enum):
    """Which time measurement an objective is based on, if any.

    Constructing a member from an unrecognised value does not raise;
    it falls back to ``UseTime.NO`` via ``_missing_``.
    """

    WALL_TIME = "WALL_TIME"
    CPU_TIME = "CPU_TIME"
    NO = "NO"

    @classmethod
    def _missing_(cls: type[UseTime], value: object) -> UseTime:
        """Map any value that matches no member to ``NO``."""
        return cls.NO

18 

19 

class SparkleObjective:
    """Objective for Sparkle specified by user.

    Bundles an objective name with the aggregation callables, the
    optimisation direction, an optional post-processing callable and the
    kind of time measurement (if any) the objective is based on.
    """

    name: str
    run_aggregator: typing.Callable
    instance_aggregator: typing.Callable
    solver_aggregator: typing.Callable
    minimise: bool
    post_process: typing.Callable | None
    use_time: UseTime

    def __init__(self: SparkleObjective,
                 name: str,
                 run_aggregator: typing.Callable = np.mean,
                 instance_aggregator: typing.Callable = np.mean,
                 solver_aggregator: typing.Callable | None = None,
                 minimise: bool = True,
                 post_process: typing.Callable | None = None,
                 use_time: UseTime = UseTime.NO) -> None:
        """Create a Sparkle objective from a name and options.

        Args:
            name: Name of the objective; also used as its string form.
            run_aggregator: Callable stored as the run-level aggregator
                (default ``np.mean``).
            instance_aggregator: Callable stored as the instance-level
                aggregator (default ``np.mean``).
            solver_aggregator: Callable stored as the solver-level
                aggregator. When None, ``np.min`` is used if ``minimise``
                is True and ``np.max`` otherwise.
            minimise: Whether the objective is minimised; also selects the
                default ``solver_aggregator``.
            post_process: Optional callable stored as-is; may be None.
            use_time: Time measure the objective is based on;
                ``UseTime.NO`` means the objective is not time based.
        """
        self.name = name
        self.run_aggregator: typing.Callable = run_aggregator
        self.instance_aggregator: typing.Callable = instance_aggregator
        if solver_aggregator is None:
            # The best solver is the lowest value when minimising,
            # the highest otherwise.
            solver_aggregator = np.min if minimise else np.max
        self.solver_aggregator: typing.Callable = solver_aggregator
        self.minimise: bool = minimise
        self.post_process: typing.Callable | None = post_process
        self.use_time: UseTime = use_time

    def __str__(self: SparkleObjective) -> str:
        """Return the objective's name."""
        return f"{self.name}"

    def __repr__(self: SparkleObjective) -> str:
        """Return a debugging representation of the objective."""
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"minimise={self.minimise!r}, use_time={self.use_time!r})")

    @property
    def time(self: SparkleObjective) -> bool:
        """Return whether the objective is time based."""
        return self.use_time != UseTime.NO

58 

59 

class PAR(SparkleObjective):
    """Penalised Averaged Runtime (PAR-k) objective for Sparkle.

    The objective is named ``PAR{k}``, measures CPU time, and uses a
    post-processing callable that replaces any value exceeding the cutoff
    with ``k`` times the cutoff.
    """

    def __init__(self: PAR, k: int = 10) -> None:
        """Initialize PAR.

        Args:
            k: Penalty multiplier applied to the cutoff; must be positive.

        Raises:
            ValueError: If ``k`` is not greater than 0.
        """
        self.k = k
        if k <= 0:
            raise ValueError("k must be greater than 0.")

        def _penalise(value: float, cutoff: float) -> float:
            """Return ``value``, or ``cutoff * self.k`` if it exceeds cutoff."""
            # Reads self.k at call time, so later changes to k are honoured.
            return cutoff * self.k if value > cutoff else value

        super().__init__(f"PAR{k}",
                         post_process=_penalise,
                         use_time=UseTime.CPU_TIME)