# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from k1lib.callbacks import Callback, Cbs
import k1lib, numpy as np; from torch import nn
from k1lib import cli
_spacing = lambda s: f"{s}   "; # inserted at end of everything, if that element existed
_lcomp = 14; _lp1 = 8; _lp2 = 15; _lp3 = 14
class ComputationData:
    """Per-module operation counter used by :class:`ComputationProfiler`.

    A forward hook is registered on the wrapped module as soon as the object is
    constructed. Every forward pass through a recognized layer type adds an
    estimated operation count to :attr:`flop`. :meth:`unhook` removes the hook
    and adds this module's count to the owning profiler's ``totalFlop``.
    """
    def __init__(self, cProfiler, mS:k1lib.selector.ModuleSelector):
        """
        :param cProfiler: owning ComputationProfiler (reads its ``totalFlop``,
            ``selectedTotalFlop``, ``selected`` and ``_tpAvailable``)
        :param mS: selector node whose ``.nn`` module gets the forward hook"""
        self.cProfiler = cProfiler; self.mS = mS
        self.flop = 0 # accumulated by the forward hook
        self.handle = None; self.hook()
        self.flops = 0 # ops/s, filled in by __str__ when time data is available
        self.tS = None # corresponding time selector, injected by ComputationProfiler.detached()
    def hook(self):
        """Registers a forward hook on ``self.mS.nn`` that estimates ops.

        Only ``nn.Linear``, ``nn.Conv2d``, ``nn.LeakyReLU``, ``nn.ReLU`` and
        ``nn.Sigmoid`` are counted. Every other module contributes nothing."""
        def hk(m, i, o):
            i = k1lib.squeeze(i)
            if isinstance(m, nn.Linear): self.flop += i.numel() * m.out_features
            elif isinstance(m, nn.Conv2d):
                # Each output element is a dot product over (in_channels / groups)
                # input channels times the kernel window. Counting from the output
                # shape accounts for stride, padding, dilation and groups; this
                # equals the old input-based estimate for stride-1 "same" convs.
                self.flop += o.numel() * (m.in_channels // m.groups) * int(np.prod(m.kernel_size))
            elif isinstance(m, (nn.LeakyReLU, nn.ReLU, nn.Sigmoid)):
                self.flop += i.numel() # one op per input element
        self.handle = self.mS.nn.register_forward_hook(hk)
    def unhook(self):
        """Adds this module's count to the profiler's ``totalFlop`` and removes
        the forward hook. Each call adds ``flop`` again, so call it once."""
        self.cProfiler.totalFlop += self.flop; self.handle.remove()
    def __getstate__(self):
        """Pickles everything except the ``mS`` and ``cProfiler`` back-references.
        They are not restored by ``__setstate__``, so ``__str__`` cannot be
        used on an unpickled copy."""
        answer = dict(self.__dict__)
        del answer["mS"]; del answer["cProfiler"]; return answer
    def __setstate__(self, state): self.__dict__.update(dict(state))
    def __str__(self):
        """Renders this module's row of the report table.

        The columns are:

        * the op count;
        * the share of the profiler's total;
        * the ops/s rate, only when time-profiler results are available;
        * the share of the highlighted selection, only while a selection is active.

        Modules with no counted ops render as an empty string."""
        if self.flop <= 0: return ""
        a = _spacing(f"{k1lib.fmt.comp(self.flop)}".ljust(_lcomp))
        b = _spacing(f"{round(100 * self.flop / self.cProfiler.totalFlop)}%".rjust(_lp1))
        c = ""
        if self.cProfiler._tpAvailable:
            # self.tS must have been injected by ComputationProfiler.detached()
            self.flops = self.flop / self.tS.data.time
            c = _spacing(f"{k1lib.fmt.compRate(self.flops)}".ljust(_lp2))
        d = ""
        if self.cProfiler.selected:
            # unselected modules still get a blank, padded cell so columns line up
            if "_compProf_" in self.mS:
                d = f"{round(100 * self.flop / self.cProfiler.selectedTotalFlop)}%"
            d = _spacing(d.rjust(_lp3))
        return f"{a}{b}{c}{d}"
class ComputationProfiler(Callback):
    """Profiles computation. Only provides reports on well-known
layers, and thus can't really be universal. Example::

    l = k1lib.Learner.sample()
    l.cbs.add(Cbs.Profiler())
    # views table
    l.Profiler.computation
    # views table highlighted
    l.Profiler.computation.css("#lin1 > #lin")
"""
    def __init__(self, profiler:"Profiler"):
        """
        :param profiler: parent Profiler. It is used to probe for time-profiler
            results (``_time()``) and to reach ``time.selector`` for ops/s."""
        super().__init__(); self.profiler = profiler
    def startRun(self):
        """Attaches a fresh :class:`ComputationData` (and thus a forward hook)
        to every module of the selector and resets the totals."""
        if not hasattr(self, "selector"): # if no selectors found
            self.selector = self.l.model.select("")
        for m in self.selector.modules(): m.data = ComputationData(self, m)
        # highlighted (css-selected) modules are rendered in red
        self.selector.displayF = lambda m: (k1lib.fmt.txt.red if "_compProf_" in m else k1lib.fmt.txt.identity)(m.data)
        self.totalFlop = 0; self.selectedTotalFlop = None
    @property
    def selected(self):
        """Whether a css selection is currently active (only true inside :meth:`css`)."""
        return self.selectedTotalFlop is not None
    @property
    def _tpAvailable(self) -> bool:
        """Whether TimeProfiler's results are available"""
        try: self.profiler._time()
        except Exception: return False # deliberate best-effort probe: any failure means "not available"
        return True
    # NOTE(review): the True return value is interpreted by the learner's callback
    # dispatch; presumably it suppresses the rest of the step so profiling does not
    # train the model. Confirm against k1lib's Callback semantics.
    def startStep(self): return True
    def _run(self):
        """Runs everything"""
        # run the learner once (l.run(1, 1)) with a temporary Cpu callback, inside
        # cbs.context()/cbs.suspendEval(); the forward hooks accumulate op counts
        with self.cbs.context(), self.cbs.suspendEval():
            self.cbs.add(Cbs.Cpu()); self.l.run(1, 1)
        for m in self.selector.modules(): m.data.unhook()
    def detached(self): # time profiler integration, so that flops can be displayed
        if self._tpAvailable:
            # assumes both selectors enumerate the same modules in the same order
            for cS, tS in zip(self.selector.modules(), self.profiler.time.selector.modules()):
                cS.data.tS = tS # injecting dependency
    def css(self, css:str):
        """Selects a small part of the network to highlight. See also: :mod:`k1lib.selector`."""
        try:
            self.selector.parse(k1lib.selector.preprocess(css, "_compProf_"))
            self.selectedTotalFlop = sum(m.data.flop for m in self.selector.modules() if "_compProf_" in m)
            print(self.__repr__())
        finally:
            # always restore the unselected state, even if parsing or rendering
            # fails, so later reprs don't show a stale "% selected" column
            self.selector.clearProps(); self.selectedTotalFlop = None
    def __repr__(self):
        # evaluate once per render; _tpAvailable probes the parent profiler
        tp = self._tpAvailable; sel = self.selected
        header = _spacing("computation".ljust(_lcomp))
        header += _spacing("% total".rjust(_lp1))
        header += _spacing("rate".ljust(_lp2)) if tp else ""
        header += _spacing("% selected".rjust(_lp3)) if sel else ""
        footer = _spacing(f"{k1lib.fmt.comp(self.totalFlop)}".ljust(_lcomp))
        footer += _spacing("".rjust(_lp1))
        footer += _spacing("".ljust(_lp2)) if tp else ""
        footer += _spacing(f"{k1lib.fmt.comp(self.selectedTotalFlop)}".rjust(_lp3)) if sel else ''
        footer = ("Total", footer)
        c = self.selector.__repr__(intro=False, header=header, footer=footer).split("\n") | cli.tab() | cli.join("\n")
        return f"""ComputationProfiler:\n{c}
The "rate" column will appear if integration with Profiler.time is
possible, showing actual ops/s
Can...
- cp.css("..."): highlights a particular part of the network
- cp.selector: to get internal k1lib.ModuleSelector object"""