Coverage for sparkle/types/objective.py: 98%
44 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 14:48 +0000
1"""Class for Sparkle Objective and Performance."""
2from __future__ import annotations
3from enum import Enum
4import typing
5import numpy as np
class UseTime(str, Enum):
    """Kind of time measurement an objective is based on.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (e.g. ``UseTime.NO == "NO"``).
    """

    WALL_TIME = "WALL_TIME"
    CPU_TIME = "CPU_TIME"
    NO = "NO"

    @classmethod
    def _missing_(cls: type[UseTime], value: object) -> UseTime:
        """Fall back to ``NO`` for any value that matches no member.

        Enum calls this hook when ``UseTime(value)`` finds no member, so an
        unrecognised value yields ``UseTime.NO`` instead of raising.
        """
        return cls.NO
class SparkleObjective:
    """Objective for Sparkle specified by user.

    Holds the objective's name, the callables used to aggregate its values
    (per run, instance and solver, as the attribute names indicate), whether
    lower values are better, an optional post-processing callable, and which
    kind of time measurement (if any) the objective is based on.
    """

    name: str
    run_aggregator: typing.Callable
    instance_aggregator: typing.Callable
    # Never None after __init__: a missing aggregator is replaced by np.min/np.max.
    solver_aggregator: typing.Callable
    minimise: bool
    # None means no post-processing callable was supplied.
    post_process: typing.Callable | None
    use_time: UseTime

    def __init__(self: SparkleObjective,
                 name: str,
                 run_aggregator: typing.Callable = np.mean,
                 instance_aggregator: typing.Callable = np.mean,
                 solver_aggregator: typing.Callable | None = None,
                 minimise: bool = True,
                 post_process: typing.Callable | None = None,
                 use_time: UseTime = UseTime.NO) -> None:
        """Create a Sparkle objective.

        Args:
            name: Identifier of the objective; also its string form.
            run_aggregator: Aggregator for run values (default ``np.mean``).
            instance_aggregator: Aggregator for instance values
                (default ``np.mean``).
            solver_aggregator: Aggregator for solver values. When ``None``,
                ``np.min`` is chosen if ``minimise`` is true, else ``np.max``.
            minimise: Whether lower values are better.
            post_process: Optional callable stored as-is; ``None`` when the
                objective has no post-processing step.
            use_time: Which kind of time the objective is based on;
                ``UseTime.NO`` for objectives that are not time based.
        """
        self.name = name
        self.run_aggregator: typing.Callable = run_aggregator
        self.instance_aggregator: typing.Callable = instance_aggregator
        if solver_aggregator is None:
            # Pick the "best" solver according to the optimisation direction.
            solver_aggregator = np.min if minimise else np.max
        self.solver_aggregator: typing.Callable = solver_aggregator
        self.minimise: bool = minimise
        self.post_process: typing.Callable | None = post_process
        self.use_time: UseTime = use_time

    def __str__(self: SparkleObjective) -> str:
        """Return the objective's name as its string representation."""
        return f"{self.name}"

    @property
    def time(self: SparkleObjective) -> bool:
        """Return whether the objective is time based (``use_time`` is not NO)."""
        return self.use_time != UseTime.NO
class PAR(SparkleObjective):
    """Penalised Averaged Runtime (PAR-k) objective for Sparkle.

    A CPU-time based objective named ``PAR{k}`` whose post-processing step
    replaces any value exceeding the cutoff by ``k`` times the cutoff.
    """

    def __init__(self: PAR, k: int = 10) -> None:
        """Initialize PAR with penalty multiplier ``k``.

        Args:
            k: Factor applied to the cutoff for values exceeding it.

        Raises:
            ValueError: If ``k`` is not greater than 0.
        """
        self.k = k
        if k <= 0:
            raise ValueError("k must be greater than 0.")

        def penalise(value: float, cutoff: float) -> float:
            """Return ``cutoff * k`` when ``value`` exceeds ``cutoff``, else ``value``."""
            # self.k is read at call time, not captured at construction.
            return cutoff * self.k if value > cutoff else value

        super().__init__(
            f"PAR{k}",
            post_process=penalise,
            use_time=UseTime.CPU_TIME,
        )