# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib; plt = k1lib.dep.plt
from functools import partial
from typing import List, Tuple, Callable, Union
try: import torch; import torch.nn as nn; hasTorch = True
except:
torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["HookParam"]
class ParamData(k1lib.Object):
    """History of summary statistics (mean, std, min, max) for one parameter.

    Every call to :meth:`update` appends exactly one value to each of the four
    lists, so they always have the same length."""
    def __init__(self):
        super().__init__()
        # NOTE: attribute creation order matters. plotF() iterates
        # ``data.state.keys()`` to choose subplots, presumably in insertion order.
        self.means = []; self.stds = []
        self.mins = []; self.maxs = []
    def update(self, torchParam:nn.Parameter):
        """Appends the current mean, std, min and max of ``torchParam``.

        :param torchParam: tensor whose ``mean()``, ``std()``, ``min()`` and
            ``max()`` results are converted to Python numbers via ``.item()``"""
        self.means.append(torchParam.mean().item())
        self.stds.append(torchParam.std().item())
        self.mins.append(torchParam.min().item())
        self.maxs.append(torchParam.max().item())
    def __len__(self):
        """Number of samples recorded so far."""
        return len(self.means)
    def __repr__(self):
        return f"""Param's saved data. Use...
- d.means: to get list of means
- d.stds: to get list of stds
- d.mins: to get list of mins
- d.maxs: to get list of maxs"""
class Param:
    """A single named model parameter plus the statistics sampled from it.

    :param name: the parameter's dotted name (as yielded by ``named_parameters()``)
    :param torchParam: the parameter object being tracked"""
    def __init__(self, name:str, torchParam:nn.Parameter):
        self.name, self.torchParam = name, torchParam
        self.data = ParamData()
        # throttle: update() only samples when this returns truthy
        # (presumably on every 3rd call -- see k1lib.Every)
        self.every = k1lib.Every(3)
    def update(self):
        """Records one sample of the (detached) parameter into ``self.data``,
        but only on calls that the ``self.every`` throttle lets through."""
        if not self.every(): return
        self.data.update(self.torchParam.detach())
    def __repr__(self):
        helpLines = [f"Param `{self.name}`. Use...",
                     "- p.torchParam: to get actual underlying parameter",
                     "- p.data: to get data stored",
                     "- cb.plot(): to quickly look at everything"]
        return "\n".join(helpLines)
@k1lib.patch(Cbs)
class HookParam(Callback):
    """Records means, stds, mins and maxs of all selected parameters"""
    def __init__(self):
        ""
        # list of tracked params, filled on the first startRun()
        super().__init__(); self.params:List[Param] = []
    def __getitem__(self, idx:Union[int, slice]):
        """Returns a single :class:`Param` for an integer index, or a new
        :class:`HookParam` holding the selected params for a slice."""
        if isinstance(idx, slice):
            answer = HookParam(); answer.params = self.params[idx]; return answer
        # any other index (int, numpy int, ...) selects one Param; an
        # out-of-range index raises IndexError, which also ends iteration
        # for callers that iterate via __getitem__ (e.g. __repr__)
        return self.params[idx]
    def __len__(self): return len(self.params)
    def _selected(self, paramName:str):
        """Whether the parameter named ``paramName`` is selected for hooking.

        Walks ``self.l.selector`` along the dotted module path (all but the last
        name component). Then checks that the reached selector contains
        "HookParam" and has an attribute named after the last component.
        Returns False if any path component is missing (KeyError)."""
        splits = paramName.split(".")
        try:
            mS = self.l.selector
            for split in splits[:-1]: mS = mS[split]
            return "HookParam" in mS and hasattr(mS, splits[-1])
        except KeyError: return False
    def startRun(self):
        # set things up the first time only, so data accumulates across runs
        if not self.params:
            self.params = [Param(k, v) for k, v in self.l.model.named_parameters() if self._selected(k)]
    def startBatch(self):
        for param in self.params: param.update()
    def css(self, css:str):
        """Creates a new HookParam object with selected modules. May be useful
for displaying a subset of the recorded data"""
        oldSelector = self.l.selector; answer = HookParam()
        self.l.selector = k1lib.selector.select(self.l.model, css)
        # always restore the learner's original selector, even if selection fails
        try: answer.params = [param for param in self.params if self._selected(param.name)]
        finally: self.l.selector = oldSelector
        return answer
    def __repr__(self):
        s = f", {len(self[0].data)} means and stds each" if len(self) > 0 else ""
        names = "\n".join([f" {i}. {p.name}" for i, p in enumerate(self)])
        return f"""{super()._reprHead}: {len(self)} params{s}:\n{names}\n
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
{super()._reprCan}"""
def plotF(params:Union[HookParam, Param, List[Param]], rangeSlice:slice):
    """Plots recorded statistics of one or more params on a 2x2 grid of axes.

    There is one subplot per field of the first param's ``data.state`` (at most
    four shown) and one line per param. A figure legend lists param names.

    :param params: a single :class:`Param`, a :class:`HookParam`, or a list of params
    :param rangeSlice: which recorded samples to show. Its ``step`` (default 1)
        is also applied when thinning the plotted points."""
    if isinstance(params, Param): params = [params]
    # presumably means/stds/mins/maxs, taken from ParamData's state
    fields = params[0].data.state.keys(); step = rangeSlice.step or 1
    fig, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=100)
    axes = axes.flatten()
    for field, ax in zip(fields, axes):
        for param in params:
            fieldData = param.data[field]
            r = k1lib.Range(len(fieldData))[rangeSlice]
            ax.plot(r.range_[::step], fieldData[r.slice_][::step])
        ax.set_title(field.capitalize())
    plt.figlegend([p.name for p in params], loc='right')
@k1lib.patch(HookParam)
@k1lib.patch(Param)
def plot(self):
    """Returns a sliceable plot of this object's recorded statistics.

    Installed as a method on both :class:`Param` and :class:`HookParam`. The
    actual drawing is delegated to ``plotF`` with ``self`` bound as the params."""
    drawer = partial(plotF, self)
    return k1lib.viz.SliceablePlot(drawer)