# Source code for k1lib.callbacks.paramFinder

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np
plt = k1lib.dep.plt
from functools import partial
__all__ = ["ParamFinder"]
@k1lib.patch(Cbs)
class ParamFinder(Callback):
    """Scans an optimizer hyperparameter over a log-spaced grid to find a good value.

Each training batch sets the scanned parameter (chosen in ``run``, ``"lr"`` by
default) of every optimizer param group to the next value in
:attr:`potentialValues`, which spans :math:`10^{-6}` to :math:`10^2`. The scan is
cancelled once every value has been tried, or once the smoothed loss exceeds
``tolerance`` times the best smoothed loss seen so far."""
    def __init__(self, tolerance:float=10):
        """Automatically finds out the right value for a specific parameter.

:param tolerance: how much higher should the loss be to be considered a failure?"""
        super().__init__()
        self.order = 23  # NOTE(review): callback ordering slot; meaning defined by k1lib's Callbacks — confirm
        # kept suspended outside of a scan; ``run`` only unsuspends it while training
        self.suspended = True
        self.losses = []
        self.tolerance = tolerance
    @property
    def samples(self):
        """Number of candidate values in the scan."""
        return self._samples
    @samples.setter
    def samples(self, samples):
        self._samples = samples
        # log-spaced candidates from 1e-6 to 1e2 (both endpoints included)
        self.potentialValues = 10**np.linspace(-6, 2, samples)
    @property
    def value(self):
        """Candidate value for the current batch, i.e. ``potentialValues[idx]``.

:raises k1lib.CancelRunException: once every candidate has been tried"""
        if self.idx >= len(self.potentialValues): raise k1lib.CancelRunException("Checked all possible param values")
        return self.potentialValues[self.idx]
    @property
    def lossAvgs(self):
        """Mean of the last (up to) 2 recorded losses, to smooth out batch noise.
Returns 0.0 when nothing has been recorded yet."""
        recent = self.losses[-2:]
        # Divide by the real window size: with a single recorded loss, dividing by 2
        # would halve it and make the first candidate look artificially good.
        return sum(recent)/len(recent) if recent else 0.0
    def startBatch(self):
        # advance to the next candidate and apply it to every optimizer param group
        self.idx += 1
        for paramGroup in self.l.opt.param_groups:
            paramGroup[self.param] = self.value
    @property
    def suggestedValue(self):
        """The suggested param value. Has to :meth:`run` first, before this value exists"""
        # half of the best-scoring candidate (presumably to stay safely below the
        # point where the loss starts rising)
        return self.best/2
    def endLoss(self):
        self.losses.append(self.l.loss)
        lossAvgs = self.lossAvgs
        if lossAvgs < self.bestLoss:
            self.best = self.value
            self.bestLoss = lossAvgs
        # NaN compares False against everything, so a diverged (NaN) loss would never
        # trip the tolerance check below and the scan would run through every
        # remaining value; ``x != x`` is the NaN test that also works for scalar tensors.
        if lossAvgs != lossAvgs: raise k1lib.CancelRunException("Loss is NaN")
        if lossAvgs > self.bestLoss * self.tolerance: raise k1lib.CancelRunException("Loss increases significantly")
    def __repr__(self):
        return f"""{self._reprHead}, use...
- pf.run(): to start scanning for good params and automatically plots
- pf.plot(): to plot
- pf.samples = ...: to set how many param values to iterate through
{self._reprCan}"""
@k1lib.patch(ParamFinder)
def run(self, param:str="lr", samples:int=300) -> float:
    """Finds the optimal param value.

:param param: key of the optimizer param-group entry to scan, like ``"lr"``
:param samples: how many samples to test between :math:`10^{-6}` to :math:`10^2`
:return: the suggested param value"""
    self.param = param; self.samples = samples
    # startBatch pre-increments idx, so start at -1: the first batch then uses
    # potentialValues[0], and losses[i] stays aligned with potentialValues[i] for plotF
    self.idx = -1; self.losses = []; self.best = None; self.bestLoss = float("inf")
    with self.cbs.suspendEval(less=["ProgressBar"]), self.l.model.paramsContext():
        self.suspended = False
        # Re-suspend even if training raises; otherwise this callback would keep
        # overriding the optimizer param in every later training run.
        # 1e3 is only an upper bound: the scan normally ends via CancelRunException.
        try: self.l.run(int(1e3))
        finally: self.suspended = True
    return self.suggestedValue
def plotF(self, _slice):
    """Plots recorded losses against the param values they were measured at.

:param _slice: slice expressed in unit range [0, 1] over the recorded losses"""
    r = k1lib.Range(len(self.losses)).fromUnit(_slice)
    plt.plot(self.potentialValues[r.slice_], self.losses[r.slice_])
    plt.xscale("log"); plt.xlabel(self.param); plt.ylabel("Loss")
@k1lib.patch(ParamFinder)
def plot(self, *args, **kwargs):
    """Plots loss at different param scales. Automatically :meth:`run` if hasn't,
returns a :class:`k1lib.viz.SliceablePlot`.

:param args: Arguments to pass through to :meth:`run` if a run is required.
    Just for convenience sake
:param kwargs: Keyword arguments passed through to :meth:`run` likewise"""
    if not self.losses: self.run(*args, **kwargs)
    print(f"Suggested param: {self.suggestedValue}"); plt.figure(dpi=120)
    return k1lib.viz.SliceablePlot(partial(plotF, self), docs="\n\nReminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way")