Source code for k1lib.callbacks.callbacks

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
Bare example of how this module works::

    import k1lib

    class CbA(k1lib.Callback):
        def __init__(self):
            super().__init__()
            self.initialState = 3
        def startBatch(self):
            print("startBatch - CbA")
        def startPass(self):
            print("startPass - CbA")

    class CbB(k1lib.Callback):
        def startBatch(self):
            print("startBatch - CbB")
        def endLoss(self):
            print("endLoss - CbB")

    # initialization
    cbs = k1lib.Callbacks()
    cbs.add(CbA()).add(CbB())
    model = lambda xb: xb + 3
    lossF = lambda y, yb: y - yb

    # training loop
    cbs("startBatch"); xb = 6; yb = 2
    cbs("startPass"); y = model(xb); cbs("endPass")
    cbs("startLoss"); loss = lossF(y, yb); cbs("endLoss")
    cbs("endBatch")

    print(cbs.CbA) # can reference the Callback object directly

The point is, you can define many :class:`Callback` classes that
define a number of checkpoint functions, like ``startBatch``. Then,
you can create a :class:`Callbacks` object that holds instances of
them. When you call ``cbs("checkpoint")``, it executes
``cb.checkpoint()`` on every non-suspended Callback object that defines it.

Pretty much everything here is built upon this. The core training loop
has nothing to do with ML stuff. In fact, it's just a bunch of
``cbs("...")`` statements. Everything meaningful about the training
loop comes from different Callback classes. The advantage of this is that
you can tack on wildly different functionalities, have them play nicely
with each other, and remove entire complex features by commenting out a
single line."""
import k1lib, time, os, logging, numpy as np
plt = k1lib.dep.plt; import k1lib.cli as cli
from typing import Set, List, Union, Callable, ContextManager, Iterator
from collections import OrderedDict
__all__ = ["Callback", "Callbacks", "Cbs"]
class Callback:
    r"""Represents a callback. Define specific functions inside to intercept
    certain parts of the training loop. Can access :class:`k1lib.Learner` like
    this::

        self.l.xb = self.l.xb[None]

    This takes x batch of learner, unsqueezes it at the 0 position, then sets the
    x batch again.

    Normally, you will define a subclass of this and define specific intercept
    functions, but if you want to create a throwaway callback, then do this::

        Callback().withCheckpoint("startRun", lambda cb: print("start running"))

    You can use :attr:`~k1lib.callbacks.callbacks.Cbs` (automatically exposed)
    for a list of default Callback classes, for any particular needs.

    **order**

    You can also use `.order` to set the order of execution of the callback. The
    higher, the later it gets executed. Value suggestions:

    - 7: pre-default runs, like LossLandscape
    - 10: default runs, like DontTrainValid
    - 13: custom mods, like ModifyBatch
    - 15: pre-recording mod
    - 17: recording mods, like Profiler.memory
    - 20: default recordings, like Loss
    - 23: post-default recordings, like ParamFinder
    - 25: guards, like TimeLimit, CancelOnExplosion

    Just leave as default (10) if you don't know what values to choose.

    **dependsOn**

    If you're going to extend this class, you can also specify dependencies like
    this::

        class CbC(k1lib.Callback):
            def __init__(self):
                super().__init__()
                self.dependsOn = {"Loss", "Accuracy"}

    This is so that if somewhere, ``Loss`` callback class is temporarily
    suspended, then CbC will be suspended also, therefore avoiding errors.

    **Suspension**

    If your Callback is mainly dormant, then you can do something like this::

        class CbD(k1lib.Callback):
            def __init__(self):
                super().__init__()
                self.suspended = True
            def startBatch(self):
                # these types of methods will only execute
                # if ``self.suspended = False``
                pass
            def analyze(self):
                self.suspended = False
                # do something that sometimes call ``startBatch``
                self.suspended = True

        cbs = k1lib.Callbacks().add(CbD())
        # dormant phase:
        cbs("startBatch") # does not execute CbD.startBatch()
        # active phase
        cbs.CbD.analyze() # does execute CbD.startBatch()

    So yeah, you can easily make every checkpoint active/dormant by changing a
    single variable, how convenient. See :meth:`Callbacks.suspend` for more."""
    def __init__(self):
        # ``l`` (learner) and ``cbs`` (parent Callbacks) are filled in by Callbacks.add()
        self.l = None; self.cbs = None; self.suspended = False
        self.name = self.__class__.__name__; self.dependsOn:Set[str] = set()
        self.order = 10 # can be modified by subclasses. A smaller order will be executed first

    def suspend(self):
        """Checkpoint, called when the Callback is temporarily suspended. Overridable"""
        pass

    def restore(self):
        """Checkpoint, called when the Callback is back from suspension. Overridable"""
        pass

    def __getstate__(self):
        # the learner and the parent Callbacks are not pickled together with the callback
        state = dict(self.__dict__); state.pop("l", None); state.pop("cbs", None); return state

    def __setstate__(self, state):
        # re-create the links dropped by __getstate__, so an unpickled callback has the
        # same attributes as a fresh one (Callbacks.__setstate__ re-links ``cbs``)
        self.l = None; self.cbs = None
        self.__dict__.update(state)

    def __repr__(self): return f"{self._reprHead}, can...\n{self._reprCan}"

    @property
    def _reprHead(self): return f"Callback `{self.name}`"

    @property
    def _reprCan(self): return """- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks"""

    def withCheckpoint(self, checkpoint:str, f:Callable[["Callback"], None]):
        """Quickly set a checkpoint, for simple, inline-able functions

        :param checkpoint: checkpoints like "startRun"
        :param f: function that takes in the Callback itself
        :returns: this Callback, for chaining"""
        setattr(self, checkpoint, lambda: f(self)); return self

    def __call__(self, checkpoint) -> bool:
        """Runs the method named ``checkpoint`` if it exists and this callback is not
        suspended.

        :returns: True if that method returned anything other than None, else False"""
        if self.suspended or not hasattr(self, checkpoint): return False
        # ``is not None`` rather than ``!= None``: return values with a custom __ne__
        # (e.g. numpy arrays) would otherwise produce non-bool, elementwise results
        return getattr(self, checkpoint)() is not None

    def attached(self):
        """Called when this is added to a :class:`Callbacks`. Override this to do
        custom stuff when this happens."""
        pass

    def detach(self):
        """Detaches from the parent :class:`Callbacks`. Does nothing if this
        callback is not attached to any.

        :returns: this Callback, for chaining"""
        if self.cbs is not None: self.cbs.remove(self.name)
        return self
# Shared k1lib.Object attached to the Callback class itself. What it holds is filled
# in elsewhere (presumably loss-related callback classes, judging by the name -- TODO confirm)
Callback.lossCls = k1lib.Object()
# Namespace of the default Callback classes (exported via __all__), used like
# ``Cbs.Loss()`` in Callbacks.withBasics()
Cbs = k1lib.Object()
class Timings:
    """List of checkpoint timings. Not intended to be instantiated by the end user.
    Used within :class:`~k1lib.callbacks.callbacks.Callbacks`, accessible via
    :attr:`Callbacks.timings` to record time taken to execute a single
    checkpoint. This is useful for profiling stuff."""
    @property
    def state(self):
        """Dict of every recorded attribute. ``getdoc`` is dropped, because
        __getattr__ auto-creates it when something probes for it (presumably
        IPython's help machinery)."""
        answer = dict(self.__dict__); answer.pop("getdoc", None); return answer

    @property
    def checkpoints(self) -> List[str]:
        """List of all checkpoints encountered"""
        return [cp for cp in self.state if k1lib.isNumeric(self[cp])]

    def __getattr__(self, attr):
        # Unknown public names spring into existence as 0, so ``t[cp] += dt`` works for a
        # checkpoint never seen before. Private/dunder lookups keep raising as usual.
        if attr.startswith("_"): raise AttributeError(attr)
        self.__dict__[attr] = 0; return 0

    def __getitem__(self, idx): return getattr(self, idx)

    def __setitem__(self, idx, value): setattr(self, idx, value)

    @staticmethod
    def _timeUnit(maxTiming):
        """Picks a display unit for a duration of ``maxTiming`` seconds.

        :returns: ``(scale, unit)`` such that ``seconds * scale`` is expressed in ``unit``"""
        if maxTiming >= 1: return 1, "s"
        if maxTiming >= 1e-3: return 1e3, "ms"
        if maxTiming >= 1e-6: return 1e6, "us"
        return 1e9, "ns" # anything faster, including nothing recorded at all

    def plot(self):
        """Plot all checkpoints' execution times. The y axis unit (s, ms, us or ns)
        is picked from the slowest checkpoint."""
        plt.figure(dpi=100); checkpoints = self.checkpoints
        timings = np.array([self[cp] for cp in checkpoints], dtype=float)
        # an empty array has no max(), so fall back to 0 when nothing was recorded
        scale, unit = self._timeUnit(timings.max() if timings.size else 0)
        plt.bar(checkpoints, timings * scale); plt.ylabel(f"Time ({unit})")
        plt.xticks(rotation="vertical"); plt.show()

    def clear(self):
        """Clears all timing data"""
        for cp in self.checkpoints: self[cp] = 0

    def __repr__(self):
        cps = '\n'.join([f'- {cp}: {self[cp]}' for cp in self.checkpoints])
        return f"""Timings object. Checkpoints:\n{cps}\n
Can...
- t.startRun: to get specific checkpoint's execution time
- t.plot(): to plot all checkpoints"""
# Module-level alias of time.time, used by Callbacks.__call__ to measure how long
# each checkpoint takes
from time import time as _time
class Callbacks:
    """Ordered collection of :class:`Callback` objects. Calling it with checkpoint
    names runs those checkpoints on every contained callback, sorted by ``order``."""
    def __init__(self):
        self._l: k1lib.Learner = None; self.cbsDict = {}
        # contexts: stack with one list per open :meth:`context`, holding the callbacks
        # added while it was the innermost one
        self._timings = Timings(); self.contexts = [[]]

    @property
    def timings(self) -> Timings:
        """Returns :class:`~k1lib.callbacks.callbacks.Timings` object"""
        return self._timings

    @property
    def l(self) -> "k1lib.Learner":
        """:class:`k1lib.Learner` object. Will be set automatically when you set
        :attr:`k1lib.Learner.cbs` to this :class:`Callbacks`"""
        return self._l

    @l.setter
    def l(self, learner):
        self._l = learner
        for cb in self.cbs: cb.l = learner

    @property
    def cbs(self) -> List[Callback]:
        """List of :class:`Callback`"""
        return [*self.cbsDict.values()] # convenience method for looping over stuff

    def _sort(self) -> "Callbacks":
        self.cbsDict = OrderedDict(sorted(self.cbsDict.items(), key=(lambda o: o[1].order))); return self

    def add(self, cb:Callback, name:str=None):
        """Adds a callback to the collection. Example::

            cbs = k1lib.Callbacks()
            cbs.add(k1lib.Callback().withCheckpoint("startBatch", lambda self: print("here")))

        If you just want to insert a simple callback with a single checkpoint, then
        you can do something like::

            cbs.add(["startBatch", lambda _: print("here")])

        :param cb: the :class:`Callback`, or a ``[checkpoint, f]`` pair
        :param name: name to register it under, defaults to ``cb.name``. If that is
            already taken, a number is appended (``Loss``, ``Loss0``, ``Loss1``, ...)
        :returns: this :class:`Callbacks`, for chaining"""
        if isinstance(cb, (list, tuple)):
            # forward ``name`` so the shorthand form can be named as well
            return self.add(Callback().withCheckpoint(cb[0], cb[1]), name)
        if not isinstance(cb, Callback): raise RuntimeError("`cb` is not a callback!")
        cb.l = self.l; cb.cbs = self
        if cb in self.cbs: return self # already added: only its links are refreshed
        name = name or cb.name
        if name in self.cbsDict:
            i = 0
            while f"{name}{i}" in self.cbsDict: i += 1
            name = f"{name}{i}"
        cb.name = name; self.cbsDict[name] = cb; self._sort()
        self._appendContext_append(cb); cb("attached"); return self

    def __contains__(self, e:str) -> bool:
        """Whether a specific Callback name is in this :class:`Callbacks`."""
        return e in self.cbsDict

    def remove(self, *names:List[str]):
        """Removes callbacks from the collection, by name. Unknown names are
        reported with a printed message and skipped; the other names are still
        removed.

        :returns: this :class:`Callbacks`, for chaining"""
        for name in names:
            if name not in self.cbsDict:
                print(f"Callback `{name}` not found"); continue
            cb = self.cbsDict.pop(name)
            # forget it in every open context() too, so that exiting the context does not
            # detach it again by name (possibly removing a newer callback reusing that name)
            for ctx in self.contexts: ctx[:] = [c for c in ctx if c is not cb]
            cb("detached")
        self._sort(); return self

    def removePrefix(self, prefix:str):
        """Removes any callback with the specified prefix"""
        for cb in self.cbs:
            if cb.name.startswith(prefix): self.remove(cb.name)
        return self

    def __call__(self, *checkpoints:List[str]) -> bool:
        """Calls a number of checkpoints one after another. Returns True if any of
        the checkpoints return anything at all"""
        self._checkpointGraph_call(checkpoints)
        answer = False
        for checkpoint in checkpoints:
            beginTime = _time()
            # a list, not a generator: every callback must run even after one returns True
            answer |= any([cb(checkpoint) for cb in self.cbs])
            self._timings[checkpoint] += _time() - beginTime
        return answer

    def __getitem__(self, idx:Union[int, str]) -> Callback:
        """Get specific cbs.

        :param idx: if :class:`str`, then get the Callback with this specific name,
            if :class:`int`, then get the Callback in that index."""
        return self.cbs[idx] if isinstance(idx, int) else self.cbsDict[idx]

    def __iter__(self) -> Iterator[Callback]:
        """Iterates through all :class:`Callback`."""
        yield from self.cbsDict.values()

    def __len__(self):
        """How many :class:`Callback` are there in total?"""
        return len(self.cbsDict)

    def __getattr__(self, attr):
        # guard: while unpickling, cbsDict does not exist yet and must not recurse
        if attr == "cbsDict": raise AttributeError(attr)
        if attr in self.cbsDict: return self.cbsDict[attr]
        else: raise AttributeError(attr)

    def __getstate__(self):
        state = dict(self.__dict__); state.pop("_l", None); return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        for cb in self.cbs: cb.cbs = self

    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbsDict.keys())
        return answer

    def __repr__(self):
        return "Callbacks:\n" + '\n'.join([f"- {cbName}" for cbName in self.cbsDict if not cbName.startswith("_")]) + """\n
Use...
- cbs.add(cb[, name]): to add a callback with a name
- cbs("startRun"): to trigger a specific checkpoint, this case "startRun"
- cbs.Loss: to get a specific callback by name, this case "Loss"
- cbs[i]: to get specific callback by index
- cbs.timings: to get callback execution times
- cbs.checkpointGraph(): to graph checkpoint calling orders
- cbs.context(): context manager that will detach all Callbacks attached inside the context
- cbs.suspend("Loss", "Cuda"): context manager to temporarily prevent triggering checkpoints"""

    def withBasics(self):
        """Adds a bunch of very basic Callbacks that are needed for everything. Also
        includes Callbacks that are not necessary, but don't slow things down"""
        self.add(Cbs.CoreNormal()).add(Cbs.Profiler()).add(Cbs.Recorder())
        self.add(Cbs.ProgressBar()).add(Cbs.Loss()).add(Cbs.Accuracy()).add(Cbs.DontTrainValid())
        return self.add(Cbs.CancelOnExplosion()).add(Cbs.ParamFinder())

    def withQOL(self):
        """Adds quality of life Callbacks."""
        return self

    def withAdvanced(self):
        """Adds advanced Callbacks that do fancy stuff, but may slow things down if
        not configured specifically."""
        return self.add(Cbs.HookModule().withMeanRecorder().withStdRecorder()).add(Cbs.HookParam())
@k1lib.patch(Callbacks)
def _resolveDependencies(self):
    """Fills ``cb._dependents`` of every callback with the callbacks whose
    ``dependsOn`` contains its class name."""
    for cb in self.cbs:
        cb._dependents:Set[Callback] = set()
        cb.dependsOn = set(cb.dependsOn)
    for cb in self.cbs:
        for cb2 in self.cbs:
            if cb2.__class__.__name__ in cb.dependsOn:
                cb2._dependents.add(cb)

class SuspendContext:
    """Context manager returned by :meth:`Callbacks.suspend` and friends. On enter,
    suspends the matching callbacks plus every callback that (transitively) depends
    on them. On exit, restores their previous ``suspended`` flags."""
    def __init__(self, cbs:Callbacks, cbsNames:List[str], cbsClasses:List[str]):
        self.cbs = cbs; self.cbsNames = cbsNames; self.cbsClasses = cbsClasses
        # one stack per Callbacks object, shared by all its contexts, so nested
        # suspensions unwind in reverse order
        self.cbs.suspendStack = getattr(self.cbs, "suspendStack", [])

    def __enter__(self):
        cbsClasses = set(self.cbsClasses); cbsNames = set(self.cbsNames)
        self._resolveDependencies()
        explored = set()
        def explore(cb:Callback):
            if cb in explored: return # guards against cyclic ``dependsOn`` declarations
            explored.add(cb)
            for dept in cb._dependents:
                cbsClasses.add(dept.__class__.__name__); explore(dept)
        # reads the live ``cbsClasses`` set, which grows while exploring
        matches = lambda cb: cb.__class__.__name__ in cbsClasses or cb.name in cbsNames
        for cb in self.cbs:
            if matches(cb): explore(cb)
        stackFrame = {cb: cb.suspended for cb in self.cbs if matches(cb)}
        for cb in stackFrame: cb.suspend(); cb.suspended = True
        self.suspendStack.append(stackFrame)
        return self.cbs

    def __exit__(self, *ignored):
        for cb, oldValue in self.suspendStack.pop().items():
            cb.suspended = oldValue; cb.restore()

    def __getattr__(self, attr):
        # guard: without ``cbs`` (e.g. a half-built object), delegating would recurse forever
        if attr == "cbs": raise AttributeError(attr)
        return getattr(self.cbs, attr)

@k1lib.patch(Callbacks)
def suspend(self, *cbNames:List[str]) -> ContextManager:
    """Creates suspension context for specified Callbacks. Matches callbacks with
    their name. Callbacks depending on a suspended one (see ``dependsOn``) get
    suspended too. Works like this::

        cbs = k1lib.Callbacks().add(CbA()).add(CbB()).add(CbC())
        with cbs.suspend("CbA", "CbC"):
            pass # inside here, only CbB will be active, and its checkpoints executed
        # CbA, CbB and CbC are all active

    .. seealso:: :meth:`suspendClasses`"""
    return SuspendContext(self, cbNames, [])

@k1lib.patch(Callbacks)
def suspendClasses(self, *classNames:List[str]) -> ContextManager:
    """Like :meth:`suspend`, but matches callbacks' class names to the given list,
    instead of matching names. Meaning::

        cbs = k1lib.Callbacks().add(Cbs.Loss()).add(Cbs.Loss())
        # cbs now has 2 callbacks "Loss" and "Loss0"
        with cbs.suspendClasses("Loss"):
            pass # now both of them are suspended"""
    return SuspendContext(self, [], classNames)

@k1lib.patch(Callbacks)
def suspendEval(self, more:List[str]=None, less:List[str]=None) -> ContextManager:
    """Same as :meth:`suspendClasses`, but suspends some default classes typically
    used for evaluation callbacks. Just a convenience method really. Currently
    includes:

    - HookModule, HookParam, ProgressBar
    - ParamScheduler, Loss, Accuracy, Autosave
    - ConfusionMatrix

    :param more: include more classes to be suspended
    :param less: exclude classes supposed to be suspended by default"""
    classes = {"HookModule", "HookParam", "ProgressBar", "ParamScheduler", "Loss", "Accuracy", "Autosave", "ConfusionMatrix"}
    classes.update(more or ()); classes -= set(less or ())
    return self.suspendClasses(*classes)

class AppendContext:
    """Context manager returned by :meth:`Callbacks.context`. Every callback added
    while it is active (including ``initCbs``) is detached again on exit."""
    def __init__(self, cbs:Callbacks, initCbs:List[Callback]=None):
        self.cbs = cbs; self.initCbs = initCbs if initCbs is not None else []

    def __enter__(self):
        self.cbs.contexts.append([])
        for cb in self.initCbs: self.cbs.add(cb)
        return self.cbs

    def __exit__(self, *ignored):
        for cb in self.cbs.contexts.pop(): cb.detach()

@k1lib.patch(Callbacks)
def _appendContext_append(self, cb):
    """Records ``cb`` in the innermost open context (see :class:`AppendContext`)."""
    self.contexts[-1].append(cb)

@k1lib.patch(Callbacks)
def context(self, *initCbs:List[Callback]) -> ContextManager:
    """Add context.

    Works like this::

        cbs = k1lib.Callbacks().add(CbA())
        # CbA is available
        with cbs.context(CbE(), CbF()):
            cbs.add(CbB())
            # CbA, CbB, CbE and CbF available
            cbs.add(CbC())
            # all 5 are available
        # only CbA is available

    For maximum shortness, you can do this::

        with k1lib.Callbacks().context(CbA()) as cbs:
            pass # CbA is available
    """
    return AppendContext(self, initCbs)

@k1lib.patch(Callbacks)
def _checkpointGraph_call(self, checkpoints:List[str]):
    """Counts, for each checkpoint, how often it directly followed each other
    checkpoint (the first call is counted after ``"<root>"``)."""
    if not hasattr(self, "_checkpointGraphDict"):
        self._checkpointGraphDict = k1lib.Object().withAutoDeclare(lambda: k1lib.Object().withAutoDeclare(lambda: 0))
        self._lastCheckpoint = "<root>"
    for cp in checkpoints:
        self._checkpointGraphDict[self._lastCheckpoint][cp] += 1
        self._lastCheckpoint = cp

@k1lib.patch(Callbacks)
def checkpointGraph(self, highlightCb:Union[str, Callback]=None):
    """Graphs what checkpoints follows what checkpoints. Has to run at least once
    first. Requires graphviz package though. Example::

        cbs = Callbacks()
        cbs("a", "b", "c", "d", "b")
        cbs.checkpointGraph() # returns graph object. Will display image if using notebooks

    .. image:: ../images/checkpointGraph.png

    :param highlightCb: if available, will highlight the checkpoints the callback uses.
        Can be name/class-name/class/self of callback.
    :raises AttributeError: if no checkpoint was called yet, or ``highlightCb``
        can't be found/understood"""
    if not hasattr(self, "_checkpointGraphDict"):
        raise AttributeError("No checkpoints recorded yet. Call this Callbacks with some checkpoints at least once first")
    g = k1lib.digraph(); s = set()
    for cp1, cp1o in self._checkpointGraphDict.state.items():
        for cp2, v in cp1o.state.items():
            g.edge(cp1, cp2, label=f" {v} "); s.add(cp2)
    if highlightCb is not None:
        _cb = None
        if isinstance(highlightCb, Callback): _cb = highlightCb
        elif isinstance(highlightCb, type) and issubclass(highlightCb, Callback):
            # first cb that is an instance of the given class
            _cb = next((cbo for cbo in self.cbs if isinstance(cbo, highlightCb)), None)
            if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks which is of type `{highlightCb.__name__}`")
        elif isinstance(highlightCb, str):
            for cbName, cbo in self.cbsDict.items():
                if cbName == highlightCb or type(cbo).__name__ == highlightCb: _cb = cbo; break
            if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks with name or class `{highlightCb}`")
        else: raise AttributeError(f"Don't understand {highlightCb}")
        print(f"Highlighting callback `{_cb.name}`, of type `{type(_cb)}`")
        for cp in s:
            if hasattr(_cb, cp): g.node(cp, color="red")
    return g