# Source code for k1lib.callbacks.landscape

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, time
plt = k1lib.dep.plt
from typing import Callable
import k1lib.cli as cli
try: import torch; hasTorch = True
except: hasTorch = False
__all__ = ["Landscape"]
# Sweep settings shared by the Landscape callback.
spacing = 0.35 # orders of magnitude between consecutive scales
offset = -2 # orders of magnitude shift: the smallest scale is 10**-2
res = 20 # resolution: grid points per axis, i.e. res*res passes per scale
# 8 log-spaced half-widths; each scale s is swept over [-s, s] on both axes
scales = 10**(np.arange(8)*spacing + offset)
scales = [round(scale, 3) for scale in scales]
# Signature of a property function: takes the Learner, returns a float
F = Callable[["k1lib.Learner"], float]
@k1lib.patch(Cbs)
class Landscape(Callback):
    """Callback that plots the landscape of a scalar property (e.g. the loss)
    of the network around its current parameters. Construct it with a property
    function, then call :meth:`plot` to run the sweep."""
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

        :param propertyF: a function that takes in :class:`k1lib.Learner` and
            outputs the desired float property
        :param name: optional callback name; falls back to the existing
            ``self.name`` when not given

        .. warning::

            Remember to detach anything you get from :class:`k1lib.Learner` in
            your function, or else you're gonna cause a huge memory leak.
        """
        super().__init__()
        self.propertyF = propertyF
        self.suspended = True # dormant until plot() switches it on
        self.name = name or self.name
        self.order = 23
        self.parent:Callback = None
    def startRun(self):
        # Snapshot the current parameters; every grid point is an offset from them.
        self.originalParams = self.l.model.exportParams()
    def endRun(self):
        # Restore the snapshotted parameters once the sweep's run ends.
        self.l.model.importParams(self.originalParams)
    def startPass(self):
        # Advance the sweep generator: it sets self.x/self.y/self.ix/self.iy for
        # this pass, or raises CancelRunException once every scale is done.
        next(self.iter)
        # Move the model to: original + x * direction1 + y * direction2
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # NaN != NaN, so NaN is stored as 0
        # Throttle the progress line to every 10th batch. The percentage is
        # progress through the current scale's res x res grid.
        if self.l.batch % 10 == 0:
            print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s ", end="")
    # The three hooks below return True for backward/step/zeroGrad; presumably
    # this tells the Learner to skip those stages so the sweep never trains the
    # model (k1lib callback convention -- confirm against Learner).
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """This one is the "core running loop", if you'd like to say so. Because
        this needs to be sort of event-triggered (by checkpoint "startPass"), so
        kinda have to put this into an iterator so that it's not the driving
        thread.

        For each scale in ``scales`` (paired with a subplot in ``self.axes``) it
        sweeps a ``res`` x ``res`` grid over ``[-scale, scale]``, yielding once
        per grid point. It then draws the collected surface on that subplot.
        After the last scale it raises :class:`k1lib.CancelRunException`, which
        is meant to stop the Learner's run."""
        self.zss = [] # debug data: one (res, res) array per scale
        nScales = len(scales)
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a)
            # zeros, not empty: a grid point whose property never got recorded
            # plots as 0 instead of uninitialised memory
            self.zs = np.zeros((res, res))
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy
                    yield True
            # Zero out non-finite values (+inf and -inf; NaN is already 0) so
            # plot_surface gets a clean surface.
            self.zs[~np.isfinite(self.zs)] = 0
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f" {i+1}/{nScales} Finished [{-scale}, {scale}] range ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots"""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # Two direction vectors spanning the plotted plane, presumably
            # per-parameter tensors from getParamsVector() -- confirm.
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            _, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            self.axes = axes.flatten()
            # Effectively unbounded epoch count: __iter__ ends the run itself.
            self.l.run(1000000)
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException:
            pass # expected end-of-sweep signal
        except Exception as e:
            # Stay best-effort (no crash), but don't hide the failure.
            print(f"\nLandscape.plot() stopped early: {e!r}")
        finally:
            # Always deactivate, even on KeyboardInterrupt.
            self.suspended = True; self.iter = None
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""