# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, math
import k1lib.cli as cli
from functools import partial
plt = k1lib.dep.plt
from typing import Callable
__all__ = ["Loss", "Accuracy"]
def plotF(losses, f):
    """Plots train and valid losses side by side in one figure.

    :param losses: object with ``.train`` and ``.valid`` loss sequences (e.g. a
        :class:`Loss` callback or its ``.epoch`` object)
    :param f: cli applied to both the x indices and the loss values of each
        subplot (e.g. a slice); it is composed with :class:`~k1lib.cli.deref` first

    Plotting is best-effort: a subplot that cannot be drawn is skipped with a
    warning, and the other subplot is still attempted."""
    import warnings
    plt.figure(figsize=(10, 3), dpi=100); f = f | cli.deref()
    for idx, split in enumerate(("train", "valid"), 1):
        try:
            values = getattr(losses, split)
            plt.subplot(1, 2, idx); plt.plot(range(len(values)) | f, values | f)
            plt.title(f"{split.capitalize()} loss")
        except Exception as e:  # narrower than a bare except: don't eat KeyboardInterrupt
            warnings.warn(f"Couldn't plot {split} loss: {e}")
def commonPlot(obj, f=cli.iden()):
    """Plots ``obj.train`` and ``obj.valid`` losses via :func:`plotF`.

    Bound as ``Loss.plot`` and ``Loss.epoch.plot``.

    :param obj: object with ``.train`` and ``.valid`` loss sequences
    :param f: optional post-processing cli (e.g. a slice), forwarded to :func:`plotF`
    :returns: None"""
    # NOTE: this function used to end with an unreachable
    # ``return k1lib.viz.SliceablePlot(partial(plotF, obj, f), docs=...)``; the figure
    # is drawn eagerly instead. Restore that return if interactive slicing is wanted.
    plotF(obj, f)
def nonEmptyList(_list):
    """Returns ``_list`` unchanged, or ``[0]`` if it is empty.

    Used to guard ``np.mean``, which returns ``nan`` (with a RuntimeWarning) on
    empty input, e.g. for an epoch that had no validation batches.

    Emptiness is tested with ``len`` rather than ``== []`` so that empty tuples
    and numpy arrays are recognized too, and so that a non-empty numpy array is
    never compared element-wise against ``[]`` (which warns or raises, depending
    on the numpy version)."""
    return _list if len(_list) > 0 else [0]
@k1lib.patch(Cbs)
class Loss(Callback):
    " "
    def __init__(self, f=lambda l: l.loss):
        """Records losses after each batch.

        Expected variables in :class:`~k1lib.Learner`:

        - loss: single float value

        :param f: optional function to get the loss from :class:`~k1lib.Learner`
            object. The loss landscape (:attr:`Landscape`) uses the same function."""
        super().__init__(); self.order = 20; self.f = f
        self.train = []; self.valid = []  # every batch loss, accumulated over all epochs
        # per-epoch averages of the above
        self.epoch = k1lib.Object.fromDict({"train": [], "valid": []}).withRepr(
            "Use...\n"
            "- `.train` for epoch-averaged training losses\n"
            "- `.valid` for epoch-averaged validation losses\n"
            "- `.plot()` to plot the 2 above")
        self.plot = partial(commonPlot, self)
        self.epoch.plot = partial(commonPlot, self.epoch)
        # losses of the epoch in progress; flushed into train/valid by endEpoch
        self._trainLosses = []; self._validLosses = []
        # reuse `f` (not a hard-coded l.loss) so a custom getter is what the landscape plots too
        self._landscape = k1lib.callbacks.Landscape(f, "_LossLandscape")

    def endLoss(self):
        """Buffers this batch's loss into the train or valid list, depending on model mode."""
        loss = self.f(self.l)
        if self.l.model.training: self._trainLosses.append(loss)
        else: self._validLosses.append(loss)

    def endEpoch(self):
        """Moves the epoch's buffered losses into train/valid and records their means.

        An epoch with no batches for a split records a mean of 0 (via nonEmptyList)."""
        self.train.extend(self._trainLosses); self.epoch.train.append(np.mean(nonEmptyList(self._trainLosses)))
        self.valid.extend(self._validLosses); self.epoch.valid.append(np.mean(nonEmptyList(self._validLosses)))
        self._trainLosses = []; self._validLosses = []

    @property
    def Landscape(self):
        """Gets loss-landscape-plotting Callback.

        Example::

            l = k1lib.Learner.sample()
            l.cbs.add(Cbs.Loss())
            l.Loss.Landscape.plot()"""
        self.cbs.add(self._landscape); return self._landscape

    def detach(self):
        """Detaches the landscape callback as well as this one."""
        self._landscape.detach(); return super().detach()

    def clear(self):
        """Clears saved data, including losses buffered for the epoch in progress"""
        self.train = []; self.epoch.train = []
        self.valid = []; self.epoch.valid = []
        # otherwise pre-clear losses would resurface at the next endEpoch
        self._trainLosses = []; self._validLosses = []

    def __repr__(self):
        return f"""{super()._reprHead}, use...
- cb.train: for all training losses over all epochs and batches (#epochs * #batches)
- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)
- cb.plot(): to plot the 2 above
- cb.clear(): to clear saved data
- cb.epoch: for average losses of each epochs
- cb.Landscape: for loss-landscape-plotting Callback
{super()._reprCan}"""
# Error text raised by Accuracy.plot and Accuracy.Landscape when no AccF callback is attached
accFMsg = ("You have to specify how to compute the accuracy "
           "with the AccF callback first")
@k1lib.patch(Cbs)
class Accuracy(Callback):
    " "
    def __init__(self, variable:str="accuracy"):
        """Records accuracies after each batch.

        Expected variables in :class:`~k1lib.Learner`:

        - accuracy: single float value from 0 to 1

        :param variable: name of variable expected to be available in Learner"""
        super().__init__(); self.order = 20
        self.train = [0]; self.valid = [0]; self.paused = True; self.variable = variable
        self._landscape = k1lib.callbacks.Landscape(lambda l: l.__dict__[variable], "_AccuracyLandscape")

    @property
    def hasAccF(self):
        """Whether an :class:`AccF` callback is attached to the Learner."""
        return any(isinstance(cb, Cbs.AccF) for cb in self.l.cbs.cbs)

    def startRun(self):
        """Pauses recording if there is no AccF callback to compute accuracies."""
        self.paused = not self.hasAccF
        if not self.paused:
            # endRun froze these into arrays; turn them back into lists so endLoss can append
            self.train = list(self.train); self.valid = list(self.valid)

    def endRun(self):
        if not self.paused:
            self.train = np.array(self.train); self.valid = np.array(self.valid)

    def endLoss(self):
        """Appends the Learner's accuracy variable to train or valid, depending on model mode."""
        if not self.paused:
            (self.train if self.l.model.training else self.valid).append(self.l.__dict__[self.variable])

    def plot(self, f=cli.iden()):
        """Plots train and valid accuracies (in percent) side by side.

        :param f: Optional post-processing cli, applied to both the batch indices
            and the accuracies (e.g. a slice)
        :raises RuntimeError: if no AccF callback is attached"""
        if not self.hasAccF: raise RuntimeError(accFMsg)
        import warnings
        plt.figure(figsize=(10, 3), dpi=100); f = f | cli.deref()
        for idx, (title, values) in enumerate((("Train", self.train), ("Valid", self.valid)), 1):
            try:
                # train/valid are plain lists outside endRun (initially, after clear(), mid-run),
                # and 100*list would repeat the list 100 times instead of scaling it
                percents = 100 * np.asarray(values)
                plt.subplot(1, 2, idx); plt.plot(range(len(percents)) | f, percents | f)
                plt.title(f"{title} accuracy")
            except Exception as e:  # best-effort, but don't hide failures or eat KeyboardInterrupt
                warnings.warn(f"Couldn't plot {title.lower()} accuracy: {e}")

    @property
    def Landscape(self):
        """Gets accuracy-landscape-plotting Callback.

        Example::

            l = k1lib.Learner.sample()
            l.cbs.add(Cbs.Accuracy())
            l.Accuracy.Landscape.plot()

        This exact example won't work, as the sample :class:`~k1lib.Learner` task is not
        categorical, but the general idea still stands"""
        if not self.hasAccF: raise RuntimeError(f"{accFMsg}, before you can view the landscape")
        self._landscape.parent = self
        self.cbs.add(self._landscape); return self._landscape

    def detach(self):
        """Detaches the landscape callback as well as this one, mirroring Loss.detach."""
        self._landscape.detach(); return super().detach()

    def clear(self):
        """Clears saved data."""
        self.train = [0]; self.valid = [0]

    def __repr__(self):
        missing = "" if self.hasAccF else " (AccF callback not added yet)"
        return f"""{super()._reprHead}{missing}, use...
- a.train: for train accuracies over all batches
- a.valid: for validation accuracies over all batches
- a.plot(): to plot the 2 above
- a.clear(): to clear saved data
- a.Landscape: for accuracy-landscape-plotting Callback
{super()._reprCan}"""