# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
Bare example of how this module works::

    import k1lib

    class CbA(k1lib.Callback):
        def __init__(self):
            super().__init__()
            self.initialState = 3
        def startBatch(self):
            print("startBatch - CbA")
        def startPass(self):
            print("startPass - CbA")

    class CbB(k1lib.Callback):
        def startBatch(self):
            print("startBatch - CbB")
        def endLoss(self):
            print("endLoss - CbB")

    # initialization
    cbs = k1lib.Callbacks()
    cbs.add(CbA()).add(CbB())
    model = lambda xb: xb + 3
    lossF = lambda y, yb: y - yb

    # training loop
    cbs("startBatch"); xb = 6; yb = 2
    cbs("startPass"); y = model(xb); cbs("endPass")
    cbs("startLoss"); loss = lossF(y, yb); cbs("endLoss")
    cbs("endBatch")

    print(cbs.CbA) # can reference the Callback object directly

The point is, you can define lots of :class:`Callback` classes that
define a number of checkpoint functions, like ``startBatch``. Then,
you can create a :class:`Callbacks` object that includes Callback
objects. When you do ``cbs("checkpoint")``, this will execute
``cb.checkpoint()`` of all the Callback objects.
Pretty much everything here is built upon this. The core training loop
has nothing to do with ML stuff. In fact, it's just a bunch of
``cbs("...")`` statements. Everything meaningful about the training
loop comes from different Callback classes. The advantage of this is that you
can tack on wildly different functions, have them play nicely with each
other, and remove entire complex functionalities by commenting out a
single line."""
import k1lib, time, os, logging, numpy as np
plt = k1lib.dep.plt; import k1lib.cli as cli
from typing import Set, List, Union, Callable, ContextManager, Iterator
from collections import OrderedDict
__all__ = ["Callback", "Callbacks", "Cbs"]
class Callback:
    r"""Represents a callback. Define specific functions
    inside to intercept certain parts of the training
    loop. Can access :class:`k1lib.Learner` like this::

        self.l.xb = self.l.xb[None]

    This takes x batch of learner, unsqueeze it at
    the 0 position, then sets the x batch again.

    Normally, you will define a subclass of this and
    define specific intercept functions, but if you
    want to create a throwaway callback, then do this::

        Callback().withCheckpoint("startRun", lambda self: print("start running"))

    You can use :attr:`~k1lib.callbacks.callbacks.Cbs` (automatically exposed) for
    a list of default Callback classes, for any particular needs.

    **order**

    You can also use `.order` to set the order of execution of the callback.
    The higher, the later it gets executed. Value suggestions:

    - 7: pre-default runs, like LossLandscape
    - 10: default runs, like DontTrainValid
    - 13: custom mods, like ModifyBatch
    - 15: pre-recording mod
    - 17: recording mods, like Profiler.memory
    - 20: default recordings, like Loss
    - 23: post-default recordings, like ParamFinder
    - 25: guards, like TimeLimit, CancelOnExplosion

    Just leave as default (10) if you don't know what values to choose.

    **dependsOn**

    If you're going to extend this class, you can also specify dependencies
    like this::

        class CbC(k1lib.Callback):
            def __init__(self):
                super().__init__()
                self.dependsOn = {"Loss", "Accuracy"}

    This is so that if somewhere, ``Loss`` callback class is temporarily
    suspended, then CbC will be suspended also, therefore avoiding errors.

    **Suspension**

    If your Callback is mainly dormant, then you can do something like this::

        class CbD(k1lib.Callback):
            def __init__(self):
                super().__init__()
                self.suspended = True
            def startBatch(self):
                # these types of methods will only execute
                # if ``self.suspended = False``
                pass
            def analyze(self):
                self.suspended = False
                # do something that sometimes call ``startBatch``
                self.suspended = True

        cbs = k1lib.Callbacks().add(CbD())
        # dormant phase:
        cbs("startBatch") # does not execute CbD.startBatch()
        # active phase
        cbs.CbD.analyze() # does execute CbD.startBatch()

    So yeah, you can easily make every checkpoint active/dormant by changing
    a single variable, how convenient. See :meth:`Callbacks.suspend`
    for more."""
    def __init__(self):
        self.l = None; self.cbs = None; self.suspended = False
        self.name = self.__class__.__name__; self.dependsOn:Set[str] = set()
        self.order = 10 # can be modified by subclasses. A smaller order will be executed first
    def suspend(self):
        """Checkpoint, called when the Callback is temporarily suspended. Overridable"""
        pass
    def restore(self):
        """Checkpoint, called when the Callback is back from suspension. Overridable"""
        pass
    def __getstate__(self):
        # drop the back-references to the Learner and the parent Callbacks, so
        # pickling a single Callback doesn't drag the whole training setup along
        state = dict(self.__dict__); state.pop("l", None); state.pop("cbs", None); return state
    def __setstate__(self, state):
        # ``l`` and ``cbs`` were stripped by __getstate__. Default them to None so the
        # attributes always exist after unpickling (the parent Callbacks re-attaches
        # ``cbs`` in its own __setstate__).
        self.l = None; self.cbs = None
        self.__dict__.update(state)
    def __repr__(self): return f"{self._reprHead}, can...\n{self._reprCan}"
    @property
    def _reprHead(self): return f"Callback `{self.name}`"
    @property
    def _reprCan(self): return """- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks"""
    def withCheckpoint(self, checkpoint:str, f:Callable[["Callback"], None]):
        """Quickly set a checkpoint, for simple, inline-able functions

        :param checkpoint: checkpoints like "startRun"
        :param f: function that takes in the Callback itself
        :return: this Callback, for chaining"""
        setattr(self, checkpoint, lambda: f(self)); return self
    def __call__(self, checkpoint):
        """Runs ``checkpoint`` on this Callback.

        :return: None if suspended or if the checkpoint isn't defined here,
            otherwise whether the handler returned something other than None"""
        if not self.suspended and hasattr(self, checkpoint):
            # ``is not None`` rather than ``!= None``: handlers may return arrays or
            # tensors, for which ``!=`` is elementwise and can't be reduced to a bool
            return getattr(self, checkpoint)() is not None
    def attached(self):
        """Called when this is added to a :class:`Callbacks`. Override this to
        do custom stuff when this happens."""
        pass
    def detach(self):
        """Detaches from the parent :class:`Callbacks`

        :return: this Callback"""
        self.cbs.remove(self.name); return self
# Class-level namespace object shared by every Callback (its contents are filled
# in elsewhere; purpose not visible in this file — NOTE(review): confirm)
Callback.lossCls = k1lib.Object()
# Namespace exposing the default Callback classes, e.g. ``Cbs.Loss`` (see Callbacks.withBasics)
Cbs = k1lib.Object()
class Timings:
    """List of checkpoint timings. Not intended to be instantiated by the end user.
    Used within :class:`~k1lib.callbacks.callbacks.Callbacks`, accessible via
    :attr:`Callbacks.timings` to record time taken to execute a single
    checkpoint. This is useful for profiling stuff."""
    @property
    def state(self):
        # ``getdoc`` is dropped because any attribute access auto-declares a
        # checkpoint (see __getattr__); presumably tools like IPython probe for
        # ``getdoc`` and create it by accident — it isn't a real checkpoint
        answer = dict(self.__dict__); answer.pop("getdoc", None); return answer
    @property
    def checkpoints(self) -> List[str]:
        """List of all checkpoints encountered"""
        return [cp for cp in self.state if k1lib.isNumeric(self[cp])]
    def __getattr__(self, attr):
        # Only reached for attributes not set yet. Private/dunder names are real
        # misses; any other name is a checkpoint that hasn't been timed yet, so it
        # is declared as 0, letting ``timings[cp] += dt`` work on first use.
        if attr.startswith("_"): raise AttributeError(attr)
        self.__dict__[attr] = 0; return 0
    def __getitem__(self, idx): return getattr(self, idx)
    def __setitem__(self, idx, value): setattr(self, idx, value)
    def plot(self):
        """Plot all checkpoints' execution times, in s/ms/us/ns depending on the
        largest one. Does nothing if no checkpoint has been recorded yet."""
        checkpoints = self.checkpoints
        if not checkpoints: return # ``max()`` of an empty array would raise
        plt.figure(dpi=100)
        timings = np.array([self[cp] for cp in checkpoints])
        maxTiming = timings.max()
        # largest unit in which the biggest timing is at least 1; anything below
        # a microsecond (previously not plotted at all) is shown in ns
        for threshold, scale, unit in ((1, 1, "s"), (1e-3, 1e3, "ms"), (1e-6, 1e6, "us")):
            if maxTiming >= threshold: break
        else: scale, unit = 1e9, "ns"
        plt.bar(checkpoints, timings*scale); plt.ylabel(f"Time ({unit})")
        plt.xticks(rotation="vertical"); plt.show()
    def clear(self):
        """Clears all timing data"""
        for cp in self.checkpoints: self[cp] = 0
    def __repr__(self):
        cps = '\n'.join([f'- {cp}: {self[cp]}' for cp in self.checkpoints])
        return f"""Timings object. Checkpoints:\n{cps}\n
Can...
- t.startRun: to get specific checkpoint's execution time
- t.plot(): to plot all checkpoints"""
# Module-level alias for the wall-clock timer that Callbacks.__call__ uses to time checkpoints
from time import time as _time
class Callbacks:
    """Ordered collection of :class:`Callback` objects.

    Calling it with checkpoint names, like ``cbs("startBatch")``, triggers that
    checkpoint on every attached Callback in ascending ``order``. Suspended ones are
    skipped by :meth:`Callback.__call__`. The time spent in each checkpoint is
    accumulated in :attr:`timings`."""
    def __init__(self):
        self._l: k1lib.Learner = None; self.cbsDict = {}
        # ``contexts`` is a stack of lists. Each list records the Callbacks added while
        # that level was innermost, so :meth:`context` can detach them on exit.
        self._timings = Timings(); self.contexts = [[]]
    @property
    def timings(self) -> Timings:
        """Returns :class:`~k1lib.callbacks.callbacks.Timings` object"""
        return self._timings
    @property
    def l(self) -> "k1lib.Learner":
        """:class:`k1lib.Learner` object. Will be set automatically when
        you set :attr:`k1lib.Learner.cbs` to this :class:`Callbacks`"""
        return self._l
    @l.setter
    def l(self, learner):
        self._l = learner
        for cb in self.cbs: cb.l = learner
    @property
    def cbs(self) -> List[Callback]:
        """List of :class:`Callback`, in execution order. A fresh copy every time,
        so it's safe to add/remove Callbacks while looping over it."""
        return [*self.cbsDict.values()]
    def _sort(self) -> "Callbacks":
        # stable sort: Callbacks with equal ``order`` keep their insertion order
        self.cbsDict = OrderedDict(sorted(self.cbsDict.items(), key=(lambda o: o[1].order))); return self
    def add(self, cb:Callback, name:str=None):
        """Adds a callback to the collection.
        Example::

            cbs = k1lib.Callbacks()
            cbs.add(k1lib.Callback().withCheckpoint("startBatch", lambda self: print("here")))

        If you just want to insert a simple callback with a single checkpoint, then
        you can do something like::

            cbs.add(["startBatch", lambda _: print("here")])

        :param cb: a :class:`Callback`, or a ``[checkpoint, f]`` pair
        :param name: name to register under. Defaults to ``cb.name``. A numeric suffix
            is appended if the name is already taken.
        :return: this :class:`Callbacks`, for chaining
        :raises RuntimeError: if ``cb`` is not a Callback"""
        if isinstance(cb, (list, tuple)):
            # shorthand form; forward ``name`` instead of silently dropping it
            return self.add(Callback().withCheckpoint(cb[0], cb[1]), name)
        if not isinstance(cb, Callback): raise RuntimeError("`cb` is not a callback!")
        if cb in self.cbs: cb.l = self.l; cb.cbs = self; return self
        cb.l = self.l; cb.cbs = self; name = name or cb.name
        if name in self.cbsDict:
            # name clash: use the smallest free integer suffix ("Loss" -> "Loss0", "Loss1", ...)
            i = 0
            while f"{name}{i}" in self.cbsDict: i += 1
            name = f"{name}{i}"
        cb.name = name; self.cbsDict[name] = cb; self._sort()
        self._appendContext_append(cb); cb("attached"); return self
    def __contains__(self, e:str) -> bool:
        """Whether a specific Callback name is in this :class:`Callbacks`."""
        return e in self.cbsDict
    def remove(self, *names:List[str]):
        """Removes callbacks from the collection by name. Unknown names are reported
        and skipped; the remaining names are still removed.

        :return: this :class:`Callbacks`, for chaining"""
        for name in names:
            if name not in self.cbsDict:
                print(f"Callback `{name}` not found"); continue
            cb = self.cbsDict.pop(name)
            # forget it in every open context too, so exiting a context won't try
            # to detach it (or whatever later took its name) a second time
            for ctx in self.contexts: ctx[:] = [c for c in ctx if c is not cb]
            cb("detached")
        self._sort(); return self
    def removePrefix(self, prefix:str):
        """Removes any callback with the specified prefix"""
        for cb in self.cbs:
            if cb.name.startswith(prefix): self.remove(cb.name)
        return self
    def __call__(self, *checkpoints:List[str]) -> bool:
        """Calls a number of checkpoints one after another.
        Returns True if any of the checkpoints return anything at all"""
        self._checkpointGraph_call(checkpoints)
        answer = False
        for checkpoint in checkpoints:
            beginTime = _time()
            # a list, not a generator: every Callback must run even after one returns True
            answer |= any([cb(checkpoint) for cb in self.cbs])
            self._timings[checkpoint] += _time() - beginTime
        return answer
    def __getitem__(self, idx:Union[int, str]) -> Callback:
        """Get specific cbs.

        :param idx: if :class:`str`, then get the Callback with this specific name,
            if :class:`int`, then get the Callback in that index."""
        return self.cbs[idx] if isinstance(idx, int) else self.cbsDict[idx]
    def __iter__(self) -> Iterator[Callback]:
        """Iterates through all :class:`Callback`."""
        for cb in self.cbsDict.values(): yield cb
    def __len__(self):
        """How many :class:`Callback` are there in total?"""
        return len(self.cbsDict)
    def __getattr__(self, attr):
        # Only reached when normal lookup fails; this exposes Callbacks by name
        # (``cbs.Loss``). ``cbsDict`` is special-cased so that a half-built object
        # (e.g. mid-unpickling) raises instead of recursing forever.
        if attr == "cbsDict": raise AttributeError(attr)
        if attr in self.cbsDict: return self.cbsDict[attr]
        else: raise AttributeError(attr)
    def __getstate__(self):
        # the Learner isn't pickled along with its Callbacks
        state = dict(self.__dict__); state.pop("_l", None); return state
    def __setstate__(self, state):
        self.__dict__.update(state)
        # ``_l`` was stripped by __getstate__. Restore it as None so that ``.l`` doesn't
        # fall through to __getattr__ and raise AttributeError.
        self.__dict__.setdefault("_l", None)
        for cb in self.cbs: cb.cbs = self
    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbsDict.keys())
        return answer
    def __repr__(self):
        return "Callbacks:\n" + '\n'.join([f"- {cbName}" for cbName in self.cbsDict if not cbName.startswith("_")]) + """\n
Use...
- cbs.add(cb[, name]): to add a callback with a name
- cbs("startRun"): to trigger a specific checkpoint, this case "startRun"
- cbs.Loss: to get a specific callback by name, this case "Loss"
- cbs[i]: to get specific callback by index
- cbs.timings: to get callback execution times
- cbs.checkpointGraph(): to graph checkpoint calling orders
- cbs.context(): context manager that will detach all Callbacks attached inside the context
- cbs.suspend("Loss", "Cuda"): context manager to temporarily prevent triggering checkpoints"""
    def withBasics(self):
        """Adds a bunch of very basic Callbacks that's needed for everything. Also
        includes Callbacks that are not necessary, but don't slow things down"""
        self.add(Cbs.CoreNormal()).add(Cbs.Profiler()).add(Cbs.Recorder())
        self.add(Cbs.ProgressBar()).add(Cbs.Loss()).add(Cbs.Accuracy()).add(Cbs.DontTrainValid())
        return self.add(Cbs.CancelOnExplosion()).add(Cbs.ParamFinder())
    def withQOL(self):
        """Adds quality of life Callbacks. Currently adds nothing."""
        return self
    def withAdvanced(self):
        """Adds advanced Callbacks that do fancy stuff, but may slow things
        down if not configured specifically."""
        return self.add(Cbs.HookModule().withMeanRecorder().withStdRecorder()).add(Cbs.HookParam())
@k1lib.patch(Callbacks)
def _resolveDependencies(self):
    """Rebuilds, for every attached Callback, ``_dependents``: the set of attached
    Callbacks whose ``dependsOn`` contains that Callback's class name. Also
    normalizes each ``dependsOn`` into a set."""
    allCbs = self.cbs
    for cb in allCbs:
        cb._dependents: Set[Callback] = set()
        cb.dependsOn = set(cb.dependsOn)
    for dependent in allCbs:
        for provider in allCbs:
            if type(provider).__name__ in dependent.dependsOn:
                provider._dependents.add(dependent)
class SuspendContext:
    """Context manager behind :meth:`Callbacks.suspend` / :meth:`Callbacks.suspendClasses`.

    On enter, it suspends every Callback whose name is in ``cbsNames`` or whose class
    name is in ``cbsClasses``. It also suspends, transitively, every Callback that
    depends on one of those. On exit, it restores each affected Callback's previous
    ``suspended`` flag and calls its ``restore()`` checkpoint. Attribute lookups not
    found here are forwarded to the wrapped Callbacks."""
    def __init__(self, cbs:Callbacks, cbsNames:List[str], cbsClasses:List[str]):
        self.cbs = cbs; self.cbsNames = cbsNames; self.cbsClasses = cbsClasses
        # one stack per Callbacks object, so nested suspensions unwind in LIFO order
        self.cbs.suspendStack = getattr(self.cbs, "suspendStack", [])
    def __enter__(self):
        cbsClasses = set(self.cbsClasses); cbsNames = set(self.cbsNames)
        self._resolveDependencies()
        visited = set()
        def explore(cb:Callback):
            # Mark the class of every transitive dependent for suspension. ``visited``
            # stops infinite recursion when ``dependsOn`` relations form a cycle.
            if cb in visited: return
            visited.add(cb)
            for dept in cb._dependents:
                cbsClasses.add(dept.__class__.__name__); explore(dept)
        for cb in self.cbs:
            if cb.__class__.__name__ in cbsClasses or cb.name in cbsNames: explore(cb)
        # remember each Callback's prior flag so nested contexts restore correctly
        stackFrame = {cb:cb.suspended for cb in self.cbs if cb.__class__.__name__ in cbsClasses or cb.name in cbsNames}
        for cb in stackFrame: cb.suspend(); cb.suspended = True
        self.suspendStack.append(stackFrame)
        return self.cbs
    def __exit__(self, *ignored):
        for cb, oldValue in self.suspendStack.pop().items():
            cb.suspended = oldValue; cb.restore()
    def __getattr__(self, attr): return getattr(self.cbs, attr)
@k1lib.patch(Callbacks)
def suspend(self, *cbNames:List[str]) -> ContextManager:
    """Creates a suspension context for the Callbacks with the given names.
    Works like this::

        cbs = k1lib.Callbacks().add(CbA()).add(CbB()).add(CbC())
        with cbs.suspend("CbA", "CbC"):
            pass # inside here, only CbB will be active, and its checkpoints executed
        # CbA, CbB and CbC are all active

    .. seealso:: :meth:`suspendClasses`"""
    return SuspendContext(self, cbsNames=cbNames, cbsClasses=[])
@k1lib.patch(Callbacks)
def suspendClasses(self, *classNames:List[str]) -> ContextManager:
    """Like :meth:`suspend`, but matches callbacks' class names to the given list,
    instead of matching names. Meaning::

        cbs = k1lib.Callbacks().add(Cbs.Loss()).add(Cbs.Loss())
        # cbs now has 2 callbacks "Loss" and "Loss0"
        with cbs.suspendClasses("Loss"):
            pass # now both of them are suspended

    :param classNames: class names (``type(cb).__name__``) of Callbacks to suspend"""
    return SuspendContext(self, [], classNames)
@k1lib.patch(Callbacks)
def suspendEval(self, more:List[str]=(), less:List[str]=()) -> ContextManager:
    """Same as :meth:`suspendClasses`, but suspends some default classes typically
    used for evaluation callbacks. Just a convenience method really. Currently includes:

    - HookModule, HookParam, ProgressBar
    - ParamScheduler, Loss, Accuracy, Autosave
    - ConfusionMatrix

    :param more: include more classes to be suspended (a single class name is also accepted)
    :param less: exclude classes supposed to be suspended by default (a single class
        name is also accepted)"""
    # a bare string would otherwise be iterated character by character
    if isinstance(more, str): more = [more]
    if isinstance(less, str): less = [less]
    classes = {"HookModule", "HookParam", "ProgressBar", "ParamScheduler", "Loss", "Accuracy", "Autosave", "ConfusionMatrix"}
    classes.update(more); classes -= set(less)
    return self.suspendClasses(*classes)
class AppendContext:
    """Context manager behind :meth:`Callbacks.context`. On enter, it pushes a new
    context level and adds ``initCbs``. On exit, it detaches every Callback that
    was added while this level was innermost."""
    def __init__(self, cbs:Callbacks, initCbs:List[Callback]=None):
        self.cbs = cbs; self.initCbs = [] if initCbs is None else initCbs
    def __enter__(self):
        self.cbs.contexts.append([])
        for cb in self.initCbs: self.cbs.add(cb)
        return self.cbs
    def __exit__(self, *ignored):
        for cb in self.cbs.contexts.pop():
            # Skip Callbacks that were already removed inside the context. This
            # includes names now taken by a different Callback: detaching by name
            # would otherwise remove the wrong one.
            if self.cbs.cbsDict.get(cb.name) is cb: cb.detach()
@k1lib.patch(Callbacks)
def _appendContext_append(self, cb):
    """Records ``cb`` in the innermost open context level, so the matching
    :meth:`context` can detach it on exit."""
    innermost = self.contexts[-1]
    innermost.append(cb)
@k1lib.patch(Callbacks)
def context(self, *initCbs:List[Callback]) -> ContextManager:
    """Add context. Every Callback attached inside the context (including
    ``initCbs``) is detached again when the context exits.
    Works like this::

        cbs = k1lib.Callbacks().add(CbA())
        # CbA is available
        with cbs.context(CbE(), CbF()):
            cbs.add(CbB())
            # CbA, CbB, CbE and CbF available
            cbs.add(CbC())
            # all 5 are available
        # only CbA is available

    For maximum shortness, you can do this::

        with k1lib.Callbacks().context(CbA()) as cbs:
            pass # CbA is available
    """
    return AppendContext(self, initCbs)
@k1lib.patch(Callbacks)
def _checkpointGraph_call(self, checkpoints:List[str]):
    """Counts transitions between consecutively triggered checkpoints (starting
    from the pseudo-checkpoint ``<root>``), for :meth:`checkpointGraph`."""
    if not hasattr(self, "_checkpointGraphDict"):
        # nested auto-declaring objects: graph[src][dst] starts at 0
        self._checkpointGraphDict = k1lib.Object().withAutoDeclare(lambda: k1lib.Object().withAutoDeclare(lambda: 0))
        self._lastCheckpoint = "<root>"
    graph = self._checkpointGraphDict
    chain = [self._lastCheckpoint, *checkpoints]
    for src, dst in zip(chain, chain[1:]):
        graph[src][dst] += 1
    self._lastCheckpoint = chain[-1]
@k1lib.patch(Callbacks)
def checkpointGraph(self, highlightCb:Union[str, Callback]=None):
    """Graphs what checkpoints follows what checkpoints. Has to run at least once
    first. Requires graphviz package though. Example::

        cbs = Callbacks()
        cbs("a", "b", "c", "d", "b")
        cbs.checkpointGraph() # returns graph object. Will display image if using notebooks

    .. image:: ../images/checkpointGraph.png

    :param highlightCb: if available, will highlight the checkpoints the callback
        uses. Can be name/class-name/class/self of callback. For a string, an exact
        Callback name takes priority over a class name.
    :raises AttributeError: if no checkpoint has been triggered yet, or if
        ``highlightCb`` can't be resolved to an attached Callback"""
    if not hasattr(self, "_checkpointGraphDict"):
        raise AttributeError("No checkpoint has been triggered yet, so there's nothing to graph. Trigger some checkpoints first, like `cbs(\"startRun\")`")
    g = k1lib.digraph(); s = set()
    for cp1, cp1o in self._checkpointGraphDict.state.items():
        for cp2, v in cp1o.state.items():
            g.edge(cp1, cp2, label=f" {v} "); s.add(cp2)
    if highlightCb is not None:
        _cb = None
        if isinstance(highlightCb, Callback): _cb = highlightCb
        elif isinstance(highlightCb, type) and issubclass(highlightCb, Callback):
            # first attached Callback that is an instance of this class
            _cb = next((cbo for cbo in self.cbs if isinstance(cbo, highlightCb)), None)
            if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks which is of type `{highlightCb.__name__}`")
        elif isinstance(highlightCb, str):
            # exact name first, then fall back to the first Callback with that class name
            _cb = self.cbsDict.get(highlightCb)
            if _cb is None: _cb = next((cbo for cbo in self.cbs if type(cbo).__name__ == highlightCb), None)
            if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks with name or class `{highlightCb}`")
        else: raise AttributeError(f"Don't understand {highlightCb}")
        print(f"Highlighting callback `{_cb.name}`, of type `{type(_cb)}`")
        for cp in s:
            if hasattr(_cb, cp): g.node(cp, color="red")
    return g