# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from k1lib.callbacks import Callback, Cbs
import k1lib; from k1lib import cli
# Width, in characters, that each shape column is left-justified to (via str.ljust)
# in IOData.__str__ and in the header built by IOProfiler.__repr__.
_li = 30
class IOData:                                                                    # IOData
    """Records the input and output shapes seen by one module's forward hook.

:param ioProfiler: the :class:`IOProfiler` that created this record
:param mS: selector whose ``nn`` module gets the forward hook"""                 # IOData
    def __init__(self, ioProfiler, mS:k1lib.selector.ModuleSelector):            # IOData
        self.ioProfiler = ioProfiler; self.mS = mS                               # IOData
        # last input/output shapes stored by the hook; None until it first fires
        self.iS = None; self.oS = None                                           # IOData
        self.handle = None; self.hook()                                          # IOData
    def hook(self):                                                              # IOData
        """Registers a forward hook on ``self.mS.nn`` that stores the shapes of
its (squeezed) input and output as lists in ``iS``/``oS``. Any hook this object
registered before is removed first, so calling :meth:`hook` again doesn't stack
duplicate hooks on the module."""                                                # IOData
        self.unhook()                                                            # IOData
        def hk(m, i, o):                                                         # IOData
            # NOTE(review): assumes the squeezed input/output has a ``.shape``
            # (i.e. a tensor); anything else raises inside the forward pass
            self.iS = list(k1lib.squeeze(i, True).shape)                         # IOData
            self.oS = list(k1lib.squeeze(o, True).shape)                         # IOData
        self.handle = self.mS.nn.register_forward_hook(hk)                       # IOData
    def unhook(self):                                                            # IOData
        """Removes the forward hook; a no-op if none was registered."""          # IOData
        if self.handle is not None: self.handle.remove()                         # IOData
    def __getstate__(self):                                                      # IOData
        # leave the back-references to the profiler and the selector out of the
        # pickled state. pop() with a default (rather than del) so a record that
        # was itself unpickled or deep-copied -- and therefore no longer has
        # them -- can be pickled/copied again instead of raising KeyError
        answer = dict(self.__dict__)                                             # IOData
        answer.pop("mS", None); answer.pop("ioProfiler", None); return answer    # IOData
    def __setstate__(self, state): self.__dict__.update(dict(state))             # IOData
    def __str__(self):                                                           # IOData
        # two columns (input shape, output shape), each left-justified to _li chars
        a = f"{self.iS}".ljust(_li); b = f"{self.oS}".ljust(_li)                 # IOData
        return f"{a}{b}"                                                         # IOData
class IOProfiler(Callback):                                                      # IOProfiler
    """Gets input and output shapes of each layer.

Example::

    l = k1lib.Learner.sample()
    l.cbs.add(Cbs.Profiler())
    # views table
    l.Profiler.io
    # views table highlighted
    l.Profiler.io.css("#lin1")
"""                                                                              # IOProfiler
    def startRun(self):                                                          # IOProfiler
        # reuse an existing selector; otherwise build one from the model with an
        # empty css string
        if not hasattr(self, "selector"): # if no selectors found                # IOProfiler
            self.selector = self.l.model.select("")                              # IOProfiler
        # every selected module gets a fresh IOData, which registers a forward hook
        for m in self.selector.modules(): m.data = IOData(self, m)               # IOProfiler
        # rows of modules tagged "_ioProf_" (the tag css() applies) are formatted
        # with k1lib.fmt.txt.red, all other rows with k1lib.fmt.txt.identity
        self.selector.displayF = lambda m: (k1lib.fmt.txt.red if "_ioProf_" in m else k1lib.fmt.txt.identity)(m.data) # IOProfiler
    # NOTE(review): returning True presumably tells the Learner to skip the
    # optimizer step, so profiling leaves the weights untouched -- confirm
    def startStep(self): return True                                             # IOProfiler
    def _run(self):                                                              # IOProfiler
        """Runs everything: does ``self.l.run(1, 1)`` inside ``self.cbs.suspendEval()``
so the forward hooks installed by :meth:`startRun` record each module's shapes,
then removes those hooks. Removal happens in a ``finally`` so a run that raises
doesn't leave the hooks attached to the model."""                                # IOProfiler
        try:                                                                     # IOProfiler
            with self.cbs.suspendEval(): self.l.run(1, 1)                        # IOProfiler
        finally:                                                                 # IOProfiler
            # the selector doesn't exist yet if the run failed before startRun()
            if hasattr(self, "selector"):                                        # IOProfiler
                for m in self.selector.modules(): m.data.unhook()                # IOProfiler
    def css(self, css:str):                                                      # IOProfiler
        """Selects a small part of the network to highlight. See also: :mod:`k1lib.selector`.

Prints the highlighted table, then calls ``self.selector.clearProps()``. That
call is in a ``finally`` so it also runs when parsing or printing raises."""     # IOProfiler
        try:                                                                     # IOProfiler
            self.selector.parse(k1lib.selector.preprocess(css, "_ioProf_"))      # IOProfiler
            print(self.__repr__())                                               # IOProfiler
        finally: self.selector.clearProps()                                      # IOProfiler
    def __repr__(self):                                                          # IOProfiler
        # the selector's table under an "input shape"/"output shape" header (rows
        # rendered by the displayF set in startRun), passed through cli.tab()
        header = "input shape".ljust(_li) + "output shape".ljust(_li)            # IOProfiler
        c = self.selector.__repr__(intro=False, header=header).split("\n") | cli.tab() | cli.join("\n") # IOProfiler
        return f"""IOProfiler:\n{c}
Can...
- iop.css("..."): highlights a particular part of the network
- iop.selector: to get internal k1lib.ModuleSelector object"""                   # IOProfiler