# Source code for k1lib.callbacks.hookParam

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib; plt = k1lib.dep.plt
from functools import partial
from typing import List, Tuple, Callable, Union
try: import torch; import torch.nn as nn; hasTorch = True
except:
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["HookParam"]
class ParamData(k1lib.Object):                                                   # ParamData
    """History of summary statistics for a single parameter.

    Each call to :meth:`update` appends one value to each of the four lists
    ``means``, ``stds``, ``mins`` and ``maxs``, so they grow in lockstep."""
    def __init__(self):                                                          # ParamData
        super().__init__()                                                       # ParamData
        self.means = []; self.stds = []                                          # ParamData
        self.mins = []; self.maxs = []                                           # ParamData
    def update(self, torchParam:nn.Parameter):                                   # ParamData
        """Appends the current mean, std, min and max of ``torchParam``.

        Each statistic is converted to a Python number with ``.item()`` before
        being stored, so the lists hold no tensors."""
        self.means.append(torchParam.mean().item())                              # ParamData
        self.stds.append(torchParam.std().item())                                # ParamData
        self.mins.append(torchParam.min().item())                                # ParamData
        self.maxs.append(torchParam.max().item())                                # ParamData
    def __len__(self):                                                           # ParamData
        """Number of snapshots recorded so far (length of ``means``)."""
        return len(self.means)                                                   # ParamData
    def __repr__(self):                                                          # ParamData
        return f"""Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""                                                 # ParamData
class Param:                                                                     # Param
    """A single named model parameter plus the statistics recorded for it.

    Args:
        name: dotted parameter name (HookParam builds these from the model's
            ``named_parameters()``)
        torchParam: the underlying parameter tensor
        every: period handed to ``k1lib.Every``; :meth:`update` only records a
            snapshot when that object returns truthy. Defaults to 3, the value
            that used to be hard-coded."""
    def __init__(self, name:str, torchParam:nn.Parameter, every:int=3):          # Param
        self.name = name                                                         # Param
        self.torchParam = torchParam                                             # Param
        self.data = ParamData()                                                  # Param
        # NOTE(review): assumes k1lib.Every(n)() is truthy once every n calls — confirm
        self.every = k1lib.Every(every)                                          # Param
    def update(self):                                                            # Param
        """Records a snapshot of the parameter's statistics when ``self.every()`` fires.

        The tensor is detached first so the statistics are computed outside autograd."""
        if self.every(): self.data.update(self.torchParam.detach())              # Param
    def __repr__(self):                                                          # Param
        return f"""Param `{self.name}`. Use...
- p.torchParam: to get actual underlying parameter
- p.data: to get data stored
- cb.plot(): to quickly look at everything"""                                    # Param
@k1lib.patch(Cbs)
class HookParam(Callback):
    """Records means, stds, mins and maxs of all selected parameters"""
    def __init__(self):
        ""  # intentionally empty docstring (kept from the original)
        super().__init__(); self.params:List[Param] = []

    def __getitem__(self, idx:Union[int, slice]):
        """Returns the :class:`Param` at an int index, or a new HookParam
        holding the sliced params for any other index (e.g. a slice)."""
        # isinstance (not `type(...) == int`) so that bools are treated as
        # indexes instead of producing a HookParam whose `params` is a Param
        if isinstance(idx, int): return self.params[idx]
        answer = HookParam(); answer.params = self.params[idx]; return answer

    def __len__(self): return len(self.params)

    def _selected(self, paramName:str):
        """Whether parameter ``paramName`` (a dotted path) is selected for
        this callback by the learner's selector.

        The module path is walked through ``self.l.selector``. A missing path
        segment (KeyError) means "not selected"."""
        *modulePath, attrName = paramName.split(".")
        try:
            mS = self.l.selector
            for split in modulePath: mS = mS[split]
        except KeyError: return False
        return "HookParam" in mS and hasattr(mS, attrName)

    def startRun(self):
        if not self.params: # set things up first time only
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)]

    def startBatch(self):
        for param in self.params: param.update()

    def css(self, css:str):
        """Creates a new HookParam object with selected modules. May be useful for displaying a subset of the recorded data"""
        oldSelector = self.l.selector; answer = HookParam()
        try:
            self.l.selector = k1lib.selector.select(self.l.model, css)
            answer.params = [param for param in self.params if self._selected(param.name)]
        finally:
            # always restore the learner's selector, even if selection/filtering raised
            self.l.selector = oldSelector
        return answer

    def __repr__(self):
        s = f", {len(self[0].data)} means and stds each" if len(self) > 0 else ""
        names = "\n".join([f" {i}. {p.name}" for i, p in enumerate(self)])
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):
    """Plots every recorded statistic of ``params`` on a 2x2 grid of axes.

    Args:
        params: a single Param, a list of Params, or a HookParam
        rangeSlice: which portion of the recorded history to draw. Its
            ``step`` (default 1) is additionally used to subsample points.

    Raises:
        IndexError: if there are no params to plot."""
    if isinstance(params, Param): params = [params]
    if len(params) == 0:
        raise IndexError("plotF: no params to plot (nothing was selected or recorded)")
    fields = params[0].data.state.keys(); step = rangeSlice.step or 1
    fig, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=100)
    axes = axes.flatten()
    for field, ax in zip(fields, axes):
        for param in params:
            fieldData = param.data[field]
            r = k1lib.Range(len(fieldData))[rangeSlice]
            ax.plot(r.range_[::step], fieldData[r.slice_][::step])
        ax.set_title(field.capitalize())
    # one legend for the whole figure; labels follow the order params were plotted in
    plt.figlegend([p.name for p in params], loc='right')

@k1lib.patch(HookParam)
@k1lib.patch(Param)
def plot(self):
    """Returns a ``k1lib.viz.SliceablePlot`` that draws this object's recorded
    statistics via :func:`plotF`."""
    return k1lib.viz.SliceablePlot(partial(plotF, self))