Source code for k1lib.callbacks.profilers.memory

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from k1lib.callbacks import Callback, Cbs; from k1lib import fmt, cli
import k1lib, torch, math, gc, numpy as np; from functools import partial
plt = k1lib.dep.plt
def allocated() -> int:                                                          # allocated
    """Number of bytes of CUDA memory currently occupied by tensors on the
    current device, exactly as reported by :func:`torch.cuda.memory_allocated`."""  # allocated
    occupied = torch.cuda.memory_allocated()                                     # allocated
    return occupied                                                              # allocated
class MemoryData:                                                                # MemoryData
    """Memory bookkeeping for a single ``nn.Module`` (one node of a module selector).

    Constructing it registers three hooks on ``mS.nn``: forward-pre ("fp"),
    forward ("f") and full backward ("b"). Each hook call runs ``gc.collect()``,
    measures ``allocated()`` minus the profiler's ``startMemory`` and adds that
    to ``values[kind]``. "f" and "b" calls additionally append
    ``[bytes, 0, mS.idx]`` to the profiler's ``stepData``. The first "b" call
    also stores ``len(stepData)`` (taken right after that append) in the
    profiler's ``startBackwardPoint``, which marks the forward/backward boundary."""
    def __init__(self, mProfiler, mS:k1lib.selector.ModuleSelector):             # MemoryData
        self.mProfiler, self.mS = mProfiler, mS                                  # MemoryData
        # two independent {fp, f, b} records: hook handles, and accumulated bytes
        self.handles = self._blankRecord()                                       # MemoryData
        self.values = self._blankRecord()                                        # MemoryData
        self.hook()                                                              # MemoryData
    @staticmethod                                                                # MemoryData
    def _blankRecord():                                                          # MemoryData
        """A fresh ``k1lib.Object`` with fp/f/b all set to 0."""                 # MemoryData
        return k1lib.Object.fromDict({"fp":0,"f":0,"b":0})                        # MemoryData
    def hook(self):                                                              # MemoryData
        """Registers the fp/f/b hooks on ``self.mS.nn`` and stores their handles."""
        profiler, selector = self.mProfiler, self.mS                             # MemoryData
        def onHook(kind, module, inp, out=None): # kind: "fp", "f" or "b"        # MemoryData
            # NOTE(review): collecting first presumably frees unreferenced tensors
            # so the reading reflects live allocations only — confirm
            gc.collect()                                                         # MemoryData
            used = allocated() - profiler.startMemory                            # MemoryData
            self.values[kind] += used                                            # MemoryData
            if kind not in ("f", "b"): return # pre-forward readings only feed `values`
            profiler.stepData.append([used, 0, selector.idx])                    # MemoryData
            # remember where backward starts, for the dashed separator line
            if kind == "b" and profiler.startBackwardPoint is None:              # MemoryData
                profiler.startBackwardPoint = len(profiler.stepData)             # MemoryData
        target = selector.nn                                                     # MemoryData
        registrations = (("fp", target.register_forward_pre_hook),               # MemoryData
                         ("f",  target.register_forward_hook),                   # MemoryData
                         ("b",  target.register_full_backward_hook))             # MemoryData
        for kind, register in registrations:                                     # MemoryData
            setattr(self.handles, kind, register(partial(onHook, kind)))         # MemoryData
    def unhook(self):                                                            # MemoryData
        """Removes the fp, f and b hooks (in that order)."""                     # MemoryData
        for kind in ("fp", "f", "b"):                                            # MemoryData
            getattr(self.handles, kind).remove()                                 # MemoryData
    def __getstate__(self):                                                      # MemoryData
        # drop the back-references to the module selector and the profiler
        state = dict(self.__dict__)                                              # MemoryData
        for key in ("mS", "mProfiler"): state.pop(key)                           # MemoryData
        return state                                                             # MemoryData
    def __setstate__(self, state):                                               # MemoryData
        self.__dict__.update(dict(state))                                        # MemoryData
    def __str__(self):                                                           # MemoryData
        """One-line summary: backward total, forward delta (f - fp), fp and f totals."""
        v = self.values                                                          # MemoryData
        columns = [f"b({fmt.size(v.b)})".ljust(13),                              # MemoryData
                   f"delta({fmt.size(v.f - v.fp)})".ljust(17),                   # MemoryData
                   f"fp({fmt.size(v.fp)})".ljust(14),                            # MemoryData
                   f"f({fmt.size(v.f)})".ljust(13)]                              # MemoryData
        return " ".join(columns)                                                 # MemoryData
class MemoryProfiler(Callback):                                                  # MemoryProfiler
    """Expected to be run only once. If a new report for a new network
    architecture is required, then create a new one. Example::

        l = k1lib.Learner.sample()
        l.cbs.add(Cbs.Profiler())
        # views graph and table
        l.Profiler.memory
        # views graph and table highlighted
        l.Profiler.memory.css("Linear")

    Every module of ``self.l.model`` gets a :class:`MemoryData`, whose hooks
    record ``allocated()`` (relative to ``startMemory``) into ``stepData`` on
    each forward and backward hook call."""                                      # MemoryProfiler
    def startRun(self):                                                          # MemoryProfiler
        """Hooks every module and resets the per-run measurements."""           # MemoryProfiler
        if not hasattr(self, "selector"): # keep a selector that was set beforehand
            self.selector = self.l.model.select("")                              # MemoryProfiler
        for mS in self.selector.modules(): mS.data = MemoryData(self, mS)        # MemoryProfiler
        self.selector.displayF = lambda mS: (fmt.txt.red if "_memProf_" in mS else fmt.txt.identity)(mS.data) # MemoryProfiler
        self.startMemory = allocated()                                           # MemoryProfiler
        # entries are mutable lists [bytes, css selected, mS.idx]; slot 1 is
        # rewritten by _updateLinState
        self.stepData = []                                                       # MemoryProfiler
        self.startBackwardPoint = None                                           # MemoryProfiler
    # NOTE(review): what returning True means is defined by the Callback protocol
    # (not visible here) — confirm
    def startStep(self): return True                                             # MemoryProfiler
    def endRun(self): self._updateLinState()                                     # MemoryProfiler
    def _run(self):                                                              # MemoryProfiler
        """Runs everything: one ``self.l.run(1, 1)`` with a Cuda callback, then
        removes all hooks, even if the run raised."""                           # MemoryProfiler
        try:                                                                     # MemoryProfiler
            with self.cbs.context(), self.cbs.suspendEval(), self.l.model.deviceContext(): # MemoryProfiler
                self.cbs.add(Cbs.Cuda()); self.l.run(1, 1)                       # MemoryProfiler
        finally:                                                                 # MemoryProfiler
            # a failed run must not leave gc-collecting hooks on the model
            self._unhookAll()                                                    # MemoryProfiler
    def _unhookAll(self):                                                        # MemoryProfiler
        """Removes the hooks of every module's MemoryData, if any were attached."""
        if not hasattr(self, "selector"): return # startRun never ran: nothing hooked
        for mS in self.selector.modules():                                       # MemoryProfiler
            data = getattr(mS, "data", None)                                     # MemoryProfiler
            if isinstance(data, MemoryData): data.unhook()                       # MemoryProfiler
    def _updateLinState(self):                                                   # MemoryProfiler
        """Change linState, which is the graph's highlight: sets each step's
        selected flag to whether its module currently carries ``_memProf_``."""  # MemoryProfiler
        @self.selector.apply                                                     # MemoryProfiler
        def applyF(mS):                                                          # MemoryProfiler
            for step in self.stepData:                                           # MemoryProfiler
                if step[2] == mS.idx: step[1] = "_memProf_" in mS                # MemoryProfiler
    def css(self, css:str):                                                      # MemoryProfiler
        """Selects a small part of the network to highlight. See also: :mod:`k1lib.selector`."""
        self.selector.parse(k1lib.selector.preprocess(css, "_memProf_"))         # MemoryProfiler
        try:                                                                     # MemoryProfiler
            self._updateLinState(); print(self.__repr__())                       # MemoryProfiler
        finally:                                                                 # MemoryProfiler
            # the highlight is temporary: reset it even if plotting fails
            self.selector.clearProps(); self._updateLinState()                   # MemoryProfiler
@k1lib.patch(MemoryProfiler)                                                     # MemoryProfiler
def __repr__(self):                                                              # __repr__
    """Plots memory over the recorded hook calls (segments of css-selected
    modules are highlighted) and returns a text summary of the network."""      # __repr__
    plt.figure(dpi=120); plt.grid(True); plt.xlabel("Time")                      # __repr__
    mem, selected, _ = self.stepData | cli.transpose() | cli.deref()             # __repr__
    label, mem = fmt.sizeOf(mem); plt.ylabel(label)                              # __repr__
    k1lib.viz.plotSegments(range(len(mem)), mem, selected)                       # __repr__
    # startBackwardPoint stays None when no backward hook fired; axvline(None)
    # would raise, so only draw the separator when there is a backward section
    hasBackward = self.startBackwardPoint is not None                            # __repr__
    if hasBackward: plt.axvline(self.startBackwardPoint, linestyle="--")         # __repr__
    ax = plt.gca(); ax.text(0.05, 0.05, "forward", transform=ax.transAxes)       # __repr__
    if hasBackward: ax.text(0.95, 0.05, "backward", ha="right", transform=ax.transAxes) # __repr__
    plt.show()                                                                   # __repr__
    c = self.selector.__repr__(intro=False).split("\n") | cli.tab() | cli.join("\n") # __repr__
    return f"""MemoryProfiler (params: {fmt.item(self.l.model.nParams)}):\n{c}
Can...
- mp.css("..."): highlights a particular part of the network
- mp.selector: to get internal k1lib.selector.ModuleSelector object"""           # __repr__