# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, time
try: import torch; hasTorch = True
except ImportError: hasTorch = False
__all__ = ["BatchLimit", "EpochLimit", "TimeLimit", "CancelOnExplosion",
"CancelOnLowLoss", "CancelOnHighAccuracy", "CancelOnOverfit", "DontTrain",
"GradientClipping", "GradientClippingNorm", "TrainOnly", "ValidOnly"]
@k1lib.patch(Cbs)
class BatchLimit(Callback): # BatchLimit
"""Cancels the epoch after executed certain number of batches""" # BatchLimit
def __init__(self, limit:int): # BatchLimit
super().__init__(); self.order = 25 # BatchLimit
        self.limit = limit if limit is not None else float("inf") # BatchLimit
def startEpoch(self): self.currentBatch = 0 # BatchLimit
def startBatch(self): # BatchLimit
if self.currentBatch >= self.limit: # BatchLimit
raise k1lib.CancelEpochException(f"Batch {self.limit} reached") # BatchLimit
def endBatch(self): self.currentBatch += 1 # BatchLimit
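# Usage sketch (hypothetical; assumes a configured k1lib.Learner instance `l`): # BatchLimit
#   l.cbs.add(Cbs.BatchLimit(100)) # every epoch now cancels itself after 100 batches # BatchLimit
#   l.run(3) # BatchLimit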
@k1lib.patch(Cbs) # EpochLimit
class EpochLimit(Callback): # EpochLimit
"""Cancels the run after executed certain number of epochs""" # EpochLimit
def __init__(self, limit:int): # EpochLimit
super().__init__(); self.order = 25 # EpochLimit
        self.limit = limit if limit is not None else float("inf") # EpochLimit
def startRun(self): self.currentEpoch = 0 # EpochLimit
def startEpoch(self): # EpochLimit
if self.currentEpoch >= self.limit: # EpochLimit
raise k1lib.CancelRunException(f"Epoch {self.limit} reached!") # EpochLimit
def endEpoch(self): self.currentEpoch += 1 # EpochLimit
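# Usage sketch (hypothetical `l` as above): # EpochLimit
#   l.cbs.add(Cbs.EpochLimit(5)) # the run cancels itself once 5 epochs have finished # EpochLimit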
@k1lib.patch(Cbs) # TimeLimit
class TimeLimit(Callback): # TimeLimit
"""Cancels the run after a certain number of seconds have passed""" # TimeLimit
def __init__(self, seconds=30): # TimeLimit
        super().__init__(); self.seconds = seconds if seconds is not None else float("inf"); self.order = 25 # TimeLimit
def startRun(self): self.startTime = time.time() # TimeLimit
def startBatch(self): # TimeLimit
if time.time() - self.startTime > self.seconds: # TimeLimit
raise k1lib.CancelRunException(f"Takes more than {self.seconds} seconds!") # TimeLimit
@k1lib.patch(Cbs) # CancelOnExplosion
class CancelOnExplosion(Callback): # CancelOnExplosion
"""Cancels the run if any of the parameters are larger than a certain limit""" # CancelOnExplosion
def __init__(self, limit:float=1e6): # CancelOnExplosion
super().__init__(); self.order = 25 # CancelOnExplosion
self.limit = limit; self.triggered = False # CancelOnExplosion
def startRun(self): self.triggered = False # CancelOnExplosion
def startBatch(self): # CancelOnExplosion
for p in self.l.model.parameters(): # CancelOnExplosion
o = p.detach() # CancelOnExplosion
if o.max().float() > self.limit or o.min().float() < -self.limit: # CancelOnExplosion
self.triggered = True # CancelOnExplosion
raise k1lib.CancelRunException("Explosion detected!") # CancelOnExplosion
def __repr__(self): # CancelOnExplosion
return f"""{self._reprHead}, use...
- cb.triggered: to see if there was an explosion on the last run
- cb.progress: to see current progress at explosion time
{self._reprCan}""" # CancelOnExplosion
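# Usage sketch (hypothetical `l` as above): # CancelOnExplosion
#   cb = Cbs.CancelOnExplosion(1e6); l.cbs.add(cb); l.run(10) # CancelOnExplosion
#   cb.triggered # True if some parameter left [-1e6, 1e6] during the run # CancelOnExplosion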
@k1lib.patch(Cbs) # CancelOnLowLoss
class CancelOnLowLoss(Callback): # CancelOnLowLoss
" " # CancelOnLowLoss
def __init__(self, loss:float, epochMode:bool=False): # CancelOnLowLoss
"""Cancels the run if loss is lower than amount specified.
Original class: :class:`~k1lib.callbacks.limits.CancelOnLowLoss`
:param epochMode: False if use batch loss, True if use valid epoch loss""" # CancelOnLowLoss
super().__init__(); self.order = 25; self.dependsOn = ["Loss"] # CancelOnLowLoss
self.loss = loss; self.epochMode = epochMode # CancelOnLowLoss
def startRun(self): # CancelOnLowLoss
if not hasattr(self.l.cbs, "Loss"): # CancelOnLowLoss
raise AttributeError("Learner does not have required `Loss` callback") # CancelOnLowLoss
        self.v = self.cbs.Loss.valid; self.ve = self.cbs.Loss.epoch.valid # List[float] # CancelOnLowLoss
def endBatch(self): # CancelOnLowLoss
if self.epochMode: # CancelOnLowLoss
if len(self.ve) > 0 and self.ve[-1] < self.loss: # CancelOnLowLoss
raise k1lib.CancelRunException(f"Low loss {self.loss} ({self.ve[-3:]} actual) achieved!") # CancelOnLowLoss
elif len(self.v) and self.v[-1] < self.loss: # CancelOnLowLoss
raise k1lib.CancelRunException(f"Low loss {self.loss} ({self.v[-3:]} actual) achieved!") # CancelOnLowLoss
@k1lib.patch(Cbs) # CancelOnHighAccuracy
class CancelOnHighAccuracy(Callback): # CancelOnHighAccuracy
"""Cancels the run if accuracy is higher than the amount specified""" # CancelOnHighAccuracy
def __init__(self, accuracy:float): # CancelOnHighAccuracy
super().__init__(); self.order = 25 # CancelOnHighAccuracy
self.accuracy = accuracy; self.dependsOn = ["Accuracy"] # CancelOnHighAccuracy
    def startRun(self): # CancelOnHighAccuracy
        if not hasattr(self.l.cbs, "Accuracy"): # CancelOnHighAccuracy
            raise AttributeError("Learner does not have required `Accuracy` callback") # CancelOnHighAccuracy
    def endBatch(self): # CancelOnHighAccuracy
        vs = self.l.Accuracy.valid # CancelOnHighAccuracy
        if len(vs) == 0: return # no valid batches recorded yet # CancelOnHighAccuracy
        a = vs[-1] # CancelOnHighAccuracy
        if a > self.accuracy: # CancelOnHighAccuracy
            raise k1lib.CancelRunException(f"High accuracy {self.accuracy} ({a} actual) achieved!") # CancelOnHighAccuracy
@k1lib.patch(Cbs) # CancelOnOverfit
class CancelOnOverfit(Callback): # CancelOnOverfit
    def __init__(self, ratio:float=1.2, alpha:float=0.99, after:int=10): # CancelOnOverfit
"""Cancels the run if overfit is detected.
:param ratio: Max ratio between the lowest loss and the current loss before cancelling the run
:param alpha: Moving average's alpha, used for both minLoss and loss estimates
:param after: After how many epochs should the overfit detection be activated?""" # CancelOnOverfit
super().__init__(); self.ratio = ratio # CancelOnOverfit
self.minLoss = k1lib.MovingAvg(alpha=alpha, debias=True) # CancelOnOverfit
self.loss = k1lib.MovingAvg(alpha=alpha, debias=True) # CancelOnOverfit
self.count = 0; self.after = after # CancelOnOverfit
def startRun(self): self.count = 0 # CancelOnOverfit
def endEpoch(self): self.count += 1 # CancelOnOverfit
def endBatch(self): # CancelOnOverfit
if not self.l.model.training: # CancelOnOverfit
loss = self.l.loss; self.loss(loss) # CancelOnOverfit
if self.loss.value < self.minLoss.value or self.minLoss.value == 0: self.minLoss(self.loss.value) # CancelOnOverfit
if self.count > self.after and self.loss.value > self.minLoss.value * self.ratio: # CancelOnOverfit
raise k1lib.CancelRunException(f"Overfit detected! Smoothed min loss: {self.minLoss.value}, loss: {loss}") # CancelOnOverfit
@k1lib.patch(Cbs) # DontTrain
class DontTrain(Callback): # DontTrain
"""Don't allow the network to train at all""" # DontTrain
    def startBackward(self): return True # returning True skips the default backward pass # DontTrain
    def startStep(self): return True # returning True skips the optimizer step # DontTrain
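# Usage sketch (hypothetical `l` as above): # DontTrain
#   l.cbs.add(Cbs.DontTrain()) # forward passes and losses still run, but weights never update # DontTrain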
if hasTorch: # GradientClipping
    from torch.nn.utils import clip_grad_value_ # GradientClipping
    @k1lib.patch(Cbs) # GradientClipping
    class GradientClipping(Callback): # GradientClipping
        """Clips each gradient element to ``[-value, value]``""" # GradientClipping
        def __init__(self, value:float): super().__init__(); self.value = value # GradientClipping
        def startStep(self): # GradientClipping
            clip_grad_value_(self.l.model.parameters(), self.value) # GradientClipping
else: # GradientClipping
    class GradientClipping(Callback): pass # stub when torch isn't available # GradientClipping
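# Usage sketch (hypothetical `l`; requires torch): # GradientClipping
#   l.cbs.add(Cbs.GradientClipping(1.0)) # clamps every gradient element to [-1, 1] before each step # GradientClipping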
if hasTorch: # GradientClippingNorm
    from torch.nn.utils import clip_grad_norm_ # GradientClippingNorm
    @k1lib.patch(Cbs) # GradientClippingNorm
    class GradientClippingNorm(Callback): # GradientClippingNorm
        """Clips gradients to a specific max_norm value. Can choose to lump
        all params together or do each one separately.
        See also: :class:`~k1lib.callbacks.limits.GradientClipping` callback.""" # GradientClippingNorm
        def __init__(self, max_norm:float, each:bool=True): # GradientClippingNorm
            super().__init__(); self.max_norm = max_norm; self.each = each # GradientClippingNorm
        def startStep(self): # GradientClippingNorm
            if self.each: # clip each parameter's gradient norm independently # GradientClippingNorm
                for p in self.l.model.parameters(): # GradientClippingNorm
                    clip_grad_norm_(p, self.max_norm) # GradientClippingNorm
            else: clip_grad_norm_(self.l.model.parameters(), self.max_norm) # all grads as one vector # GradientClippingNorm
else: # GradientClippingNorm
    class GradientClippingNorm(Callback): pass # stub when torch isn't available # GradientClippingNorm
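# Usage sketch (hypothetical `l`; requires torch): # GradientClippingNorm
#   l.cbs.add(Cbs.GradientClippingNorm(1.0)) # rescale each param's gradient to norm <= 1 # GradientClippingNorm
#   l.cbs.add(Cbs.GradientClippingNorm(1.0, each=False)) # or clip all gradients as one big vector # GradientClippingNorm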
@k1lib.patch(Cbs) # TrainOnly
class TrainOnly(Callback): # TrainOnly
" " # TrainOnly
def __init__(self, cb): # TrainOnly
"""Only executes specified callback when training. This modifies the callback's
``suspended`` variable, so it may interfere with :meth:`k1lib.callbacks.callbacks.Callbacks.suspend`
by setting it to different values while in the context.""" # TrainOnly
super().__init__(); self.cb = cb # TrainOnly
def startBatch(self): # TrainOnly
self.cb.suspended = not self.l.model.training # TrainOnly
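# Usage sketch (hypothetical `l` and callback instance `myCb`): # TrainOnly
#   l.cbs.add(Cbs.TrainOnly(myCb)) # myCb is suspended whenever the model isn't training # TrainOnly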
@k1lib.patch(Cbs) # TrainOnly
class ValidOnly(Callback): # ValidOnly
" " # ValidOnly
def __init__(self, cb): # ValidOnly
"""Same as :class:`TrainOnly`, but only executes specified callback when doing
validation.""" # ValidOnly
super().__init__(); self.cb = cb # ValidOnly
def startBatch(self): # ValidOnly
self.cb.suspended = self.l.model.training # ValidOnly
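# Usage sketch (hypothetical `l` and callback instance `myCb`): # ValidOnly
#   l.cbs.add(Cbs.ValidOnly(myCb)) # myCb is suspended whenever the model is training # ValidOnly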