# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from k1lib.callbacks import Callback, Cbs
import k1lib; from k1lib import cli
_li = 30
class IOData:
    """Records the input and output shapes seen by one module.

    On construction a forward hook is registered on ``mS.nn``. Every forward
    call through that module overwrites :attr:`iS` / :attr:`oS` with the
    latest (squeezed) shapes. Call :meth:`unhook` to detach the hook.
    """
    def __init__(self, ioProfiler, mS:k1lib.selector.ModuleSelector):
        """
        :param ioProfiler: the owning profiler, kept as a back-reference (not pickled)
        :param mS: selector wrapping the module to observe; its ``.nn`` gets the hook"""
        self.ioProfiler = ioProfiler; self.mS = mS
        self.iS = None  # last input shape as a list, None until a forward pass happens
        self.oS = None  # last output shape as a list, None until a forward pass happens
        self.handle = None; self.hook()
    @staticmethod
    def _shape(x):
        """Returns the shape of ``k1lib.squeeze(x, True)`` as a list, or None if it has no ``.shape``.

        An exception raised inside a forward hook aborts the model's forward
        pass. Values without a ``.shape`` (e.g. ``None`` or dict-like module
        outputs) are therefore recorded as None instead of crashing the run."""
        shape = getattr(k1lib.squeeze(x, True), "shape", None)
        return None if shape is None else list(shape)
    def hook(self):
        """Registers the forward hook that captures input/output shapes on ``mS.nn``."""
        def hk(m, i, o):
            self.iS = self._shape(i)
            self.oS = self._shape(o)
        self.handle = self.mS.nn.register_forward_hook(hk)
    def unhook(self):
        """Removes the forward hook. Calling it again (or before hooking) is a no-op."""
        if self.handle is not None:
            self.handle.remove(); self.handle = None
    def __getstate__(self):
        # The selector and the owning profiler are dropped when pickling.
        # NOTE: an unpickled instance therefore has no ``mS`` and cannot re-``hook()``.
        answer = dict(self.__dict__)
        del answer["mS"]; del answer["ioProfiler"]; return answer
    def __setstate__(self, state): self.__dict__.update(dict(state))
    def __str__(self):
        """Input shape and output shape, each left-justified to ``_li`` characters."""
        a = f"{self.iS}".ljust(_li); b = f"{self.oS}".ljust(_li)
        return f"{a}{b}"
class IOProfiler(Callback):
    """Gets input and output shapes of each layer.

    Example::

        l = k1lib.Learner.sample()
        l.cbs.add(Cbs.Profiler())
        # views table
        l.Profiler.io
        # views table highlighted
        l.Profiler.io.css("#lin1")
    """
    def startRun(self):
        """Attaches an :class:`IOData` shape recorder (forward hook) to every selected module.

        If no selector has been set yet, every module of the learner's model is selected."""
        if not hasattr(self, "selector"): # if no selectors found
            self.selector = self.l.model.select("")
        for m in self.selector.modules(): m.data = IOData(self, m)
        # modules tagged with the "_ioProf_" prop (see css()) are rendered in red
        self.selector.displayF = lambda m: (k1lib.fmt.txt.red if "_ioProf_" in m else k1lib.fmt.txt.identity)(m.data)
    def startStep(self): return True
    def _run(self):
        """Runs everything: one epoch/batch with eval suspended, then detaches all hooks."""
        try:
            with self.cbs.suspendEval(): self.l.run(1, 1)
        finally:
            # Detach even if the run failed; otherwise the forward hooks stay
            # registered and keep firing on every later forward pass.
            self._unhookAll()
    def _unhookAll(self):
        """Removes the :class:`IOData` forward hooks installed by :meth:`startRun`, if any."""
        if not hasattr(self, "selector"): return  # run failed before any hook was installed
        for m in self.selector.modules():
            data = getattr(m, "data", None)
            if isinstance(data, IOData): data.unhook()
    def css(self, css:str):
        """Selects a small part of the network to highlight. See also: :mod:`k1lib.selector`.

        Prints the highlighted table, then calls ``selector.clearProps()`` so
        the highlight does not persist. The props are cleared even if printing fails."""
        self.selector.parse(k1lib.selector.preprocess(css, "_ioProf_"))
        try: print(self.__repr__())
        finally: self.selector.clearProps()
    def __repr__(self):
        header = "input shape".ljust(_li) + "output shape".ljust(_li)
        c = self.selector.__repr__(intro=False, header=header).split("\n") | cli.tab() | cli.join("\n")
        return f"""IOProfiler:\n{c}
Can...
- iop.css("..."): highlights a particular part of the network
- iop.selector: to get internal k1lib.ModuleSelector object"""