# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np
plt = k1lib.dep.plt
from functools import partial
__all__ = ["ParamFinder"]
@k1lib.patch(Cbs)
class ParamFinder(Callback):
    """Sweeps one optimizer param over log-spaced candidate values, one value
    per batch, recording the loss seen at each value and remembering the best.

    Drive it with ``run()`` / ``plot()`` (patched onto this class)."""

    def __init__(self, tolerance:float=10):
        """Automatically finds out the right value for a specific parameter.

        :param tolerance: how much higher should the loss be to be considered a failure?
            The sweep stops once the averaged loss exceeds ``tolerance`` times
            the best averaged loss seen so far."""
        super().__init__(); self.order = 23
        # starts suspended; run() only un-suspends it for the duration of a sweep
        self.suspended = True; self.losses = []; self.tolerance = tolerance

    @property
    def samples(self):
        """Number of candidate param values to scan."""
        return self._samples

    @samples.setter
    def samples(self, samples):
        self._samples = samples
        # `samples` log-spaced candidates from 1e-6 to 1e2 (inclusive), increasing
        self.potentialValues = 10**np.linspace(-6, 2, samples)

    @property
    def value(self):
        """Candidate param value at the current index ``self.idx``.

        :raises k1lib.CancelRunException: once every candidate has been tried"""
        if self.idx >= len(self.potentialValues): raise k1lib.CancelRunException("Checked all possible param values")
        return self.potentialValues[self.idx]

    @property
    def lossAvgs(self):
        """Mean of the last (up to) 2 recorded losses.

        Divides by the number of losses actually available, so the very first
        batch is not reported at half of its real loss. Returns 0.0 when no
        loss has been recorded."""
        recent = self.losses[-2:]
        return sum(recent)/len(recent) if recent else 0.0

    def startBatch(self):
        """Advances to the next candidate and writes it into every optimizer param group."""
        self.idx += 1
        value = self.value  # raises CancelRunException once candidates run out
        for paramGroup in self.l.opt.param_groups:
            paramGroup[self.param] = value

    @property
    def suggestedValue(self):
        """The suggested param value. Has to :meth:`run` first, before
        this value exists"""
        # Half of the best-performing candidate.
        # NOTE(review): presumably halved to stay on the safe side of the best value; confirm.
        return self.best/2

    def endLoss(self):
        """Records this batch's loss, tracks the best candidate, and aborts on divergence.

        :raises k1lib.CancelRunException: if the averaged loss exceeds
            ``tolerance`` times the best averaged loss seen so far"""
        self.losses.append(self.l.loss)
        lossAvgs = self.lossAvgs
        if lossAvgs < self.bestLoss:
            self.best = self.value
            self.bestLoss = lossAvgs
        if lossAvgs > self.bestLoss * self.tolerance: raise k1lib.CancelRunException("Loss increases significantly")

    def __repr__(self):
        return f"""{self._reprHead}, use...
- pf.run(): to start scanning for good params and automatically plots
- pf.plot(): to plot
- pf.samples = ...: to set how many param values to iterate through
{self._reprCan}"""
@k1lib.patch(ParamFinder)
def run(self, param:str="lr", samples:int=300) -> float:
    """Finds the optimum param value.

    Each batch sets ``param`` in every optimizer param group to the next
    candidate value and records the resulting loss. The sweep ends when all
    candidates are exhausted or the loss diverges (both signalled by
    ``k1lib.CancelRunException``).

    :param param: key in the optimizer's param groups to scan, e.g. "lr"
    :param samples: how many samples to test between :math:`10^{-6}` to :math:`10^2`
    :return: the suggested param value"""
    self.param = param; self.samples = samples
    # startBatch() increments idx *before* reading the value. Start at -1 so
    # the first batch uses potentialValues[0] and losses[i] lines up with
    # potentialValues[i], which plotF relies on.
    self.idx = -1; self.losses = []; self.best = None; self.bestLoss = float("inf")
    # NOTE(review): presumably suspendEval silences other callbacks (except
    # ProgressBar) and paramsContext restores model params afterwards; confirm.
    with self.cbs.suspendEval(less=["ProgressBar"]), self.l.model.paramsContext():
        self.suspended = False
        # 1000 epochs is only an upper bound; the sweep is expected to be cut
        # short by CancelRunException. Always re-suspend, even if run() raises.
        try: self.l.run(int(1e3))
        finally: self.suspended = True
    return self.suggestedValue
def plotF(self, _slice):
    """Draws recorded losses against the param values tried, on a log x-axis.

    :param _slice: slice expressed in the unit range [0, 1]; mapped onto the
        indices of ``self.losses`` through :class:`k1lib.Range`"""
    window = k1lib.Range(len(self.losses)).fromUnit(_slice).slice_
    xs = self.potentialValues[window]
    ys = self.losses[window]
    plt.plot(xs, ys)
    plt.xscale("log")
    plt.xlabel(self.param)
    plt.ylabel("Loss")
@k1lib.patch(ParamFinder)
def plot(self, *args, **kwargs):
    """Plots the loss recorded at each param value.

    If no losses have been recorded yet, :meth:`run` is called first.
    Prints the suggested value, then returns a :class:`k1lib.viz.SliceablePlot`.

    :param args: Arguments forwarded to :meth:`run` when a run is needed.
        Just for convenience sake
    :param kwargs: Keyword arguments forwarded to :meth:`run` likewise"""
    if not self.losses:
        self.run(*args, **kwargs)
    print(f"Suggested param: {self.suggestedValue}")
    plt.figure(dpi=120)
    reminder = "\n\nReminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way"
    return k1lib.viz.SliceablePlot(partial(plotF, self), docs=reminder)