Source code for k1lib._learner

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, dill, traceback
from k1lib.callbacks import Cbs
from typing import Union
from time import time as _time
# torch is an optional dependency: `hasTorch` records whether it imported.
try: import torch; import torch.nn as nn; hasTorch = True
except Exception: # not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate
    # Placeholders built from k1lib.Object with an auto-declare factory that
    # returns a fresh dummy class; presumably this lets references such as
    # `nn.Module` still resolve without torch installed -- TODO confirm.
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
# Names exported by `from <this module> import *`: the three loop-control
# exceptions and the Learner class.
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
           "Learner"]
class CancelRunException(Exception):
    """Used in core training loop, to skip the run entirely"""
class CancelEpochException(Exception):
    """Used in core training loop, to skip to next epoch"""
class CancelBatchException(Exception):
    """Used in core training loop, to skip to next batch"""
def _tab(text:Union[list, str], pad=" ") -> Union[list, str]: # _tab if isinstance(text, str): # this is old function that's replaced in main lib, but still useful # _tab return "\n".join([pad + line for line in text.split("\n")]) # _tab else: return [pad + line for line in text] # _tab
class Learner:
    """Holds a model, a data object, an optimizer and a
    :class:`~k1lib.callbacks.callbacks.Callbacks` object. Attribute lookups
    that are not found on the learner itself are forwarded to ``self.cbs``
    (see :meth:`__getattr__`)."""
    def __init__(self):
        self._model = None; self._data = None; self._opt = None
        self._cbs = None; self.fileName = None
        # goes through the `css` setter; no selector is built yet because the model is still None
        self.css = "*"; self.exceptionRaised = None # slowly pops
        self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced()
    @property
    def model(self):
        """Set this to change the model to run"""
        return self._model
    @model.setter
    def model(self, model): self._model = model
    @property
    def data(self):
        """Set this to change the data (list of 2 dataloader) to run against."""
        return self._data
    @data.setter
    def data(self, data): self._data = data
    @property
    def opt(self):
        """Set this to change the optimizer. If you're making your own
        optimizers, beware to follow the PyTorch's style guide as there are
        callbacks that modifies optimizer internals while training like
        :class:`k1lib.schedule.ParamScheduler`."""
        return self._opt
    @opt.setter
    def opt(self, opt): self._opt = opt
    @property
    def cbs(self):
        """The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
        include all the common callbacks. You can set a new one if you want to."""
        return self._cbs
    @cbs.setter
    def cbs(self, cbs): cbs.l = self; self._cbs = cbs # back-link, so the callbacks can reach this learner
    @property
    def css(self) -> str:
        """The css selector string. Set this to select other parts of the network.
        After setting, you can access the selector like this: :code:`l.selector`

        See also: :class:`~k1lib.selector.ModuleSelector`"""
        return self._css
    @css.setter
    def css(self, css:str):
        self._css = css
        # the selector can only be (re)built once there is a model to select from
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css)
    @property
    def lossF(self):
        """Set this to specify a loss function."""
        raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to Cbs.LossF")
    @lossF.setter
    def lossF(self, lossF):
        # reuse the existing LossF callback if there is one, else register a new one
        if hasattr(self.cbs, "LossF"): self.cbs.LossF.lossF = lossF
        else: self.cbs.add(Cbs.LossF(lossF))
    def __getattr__(self, attr):
        """Forwards attributes that normal lookup can't find to ``self.cbs``.

        :raises AttributeError: for ``cbs`` / ``_cbs`` themselves. ``_cbs`` is
            absent on an instance whose ``__dict__`` is not populated yet (for
            example one created by ``__new__`` before ``__setstate__`` runs);
            without this guard the ``cbs`` property and this hook would keep
            calling each other until a RecursionError."""
        if attr in ("cbs", "_cbs"): raise AttributeError(attr)
        return getattr(self.cbs, attr)
    def __getstate__(self):
        """Returns a copy of ``__dict__`` without ``selector`` and ``_data``."""
        answer = dict(self.__dict__); answer.pop("selector", None)
        answer.pop("_data", None); return answer
    def __setstate__(self, state):
        """Restores ``__dict__``, resets data to None, then re-runs the ``css``
        setter and re-attaches this learner to its callbacks."""
        self.__dict__.update(state)
        self.__dict__["_data"] = None
        self.css = self.css; self.cbs.l = self
    def evaluate(self): pass # supposed to be overriden, to provide functionality here
    @property
    def _warnings(self):
        """String with one ``Warning: ...`` line for each of model, loss
        function callback, data and optimizer that is missing. Empty string if
        nothing is missing, else terminated by 2 extra newlines."""
        warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model is None else ""
        lossClasses = tuple(k1lib.Callback.lossCls)
        if not any(isinstance(cb, lossClasses) for cb in self.cbs):
            warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.add(Cbs.LossF(...))`\n"
        # identity checks: `== None` would be evaluated element-wise by array-like data objects
        if self.data is None: warnings += "Warning: no data yet. Set using `l.data = ...`\n"
        if self.opt is None: warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n"
        if warnings != "": warnings += "\n\n"
        return warnings
    def __dir__(self):
        """Regular ``dir()`` entries plus the keys of ``self.cbs.cbsDict``, as
        those are reachable through :meth:`__getattr__`."""
        answer = list(super().__dir__())
        answer.extend(self.cbs.cbsDict.keys()); return answer
    def __repr__(self):
        return f"""{self._warnings}l.model:\n{_tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{_tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{_tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"\n\n"""
@k1lib.patch(Learner)
def save(self, fileName:str=None):
    """Saves this :class:`Learner` to file. See also: :meth:`load`. Does not
    save the ``data`` object, because that's potentially very big. Example::

        l = k1.Learner()
        # saves learner to "skip1_128bs.pth" and model to "skip1_128bs.model.pth"
        l.save("skip1_128bs")

    :param fileName: name to save file into. If not given, falls back to
        ``l.fileName``
    :raises ValueError: if no file name is available at all"""
    # NOTE(review): falling back to `l.fileName` assumes that attribute is meant
    # to hold a default save name -- confirm. Previously a missing name silently
    # wrote to "None.pth" and "None.model.pth".
    if fileName is None: fileName = getattr(self, "fileName", None)
    if fileName is None: raise ValueError("No file name given. Pass one in, like `l.save(\"skip1_128bs\")`, or set `l.fileName` first")
    torch.save(self, f"{fileName}.pth", pickle_module=dill)
    torch.save(self.model, f"{fileName}.model.pth", pickle_module=dill)
    print(f"Saved to {fileName}")
@k1lib.patch(Learner, static=True)
def load(fileName:str=None):
    """Loads a :class:`Learner` from a file. See also: :meth:`save`. Example::

        # this will load up learner in file "skip1_128bs.pth"
        l = k1.Learner.load("skip1_128bs")

    :param fileName: if empty, then will prompt for file name"""
    f = fileName or input("Enter learner file name to load:")
    # NOTE(security): this unpickles with dill, which can run arbitrary code
    # embedded in the file. Only load files you trust.
    learner = torch.load(f"{f}.pth", pickle_module=dill)
    # report only after loading actually succeeded
    print(f"Loaded from {f}"); return learner
@k1lib.patch(Learner)
def _run1Batch(self):
    """Runs a single batch: forward pass, loss, backward, optimizer step and
    zero grad, each announced to the callbacks. Backward/step/zero grad are
    skipped if the corresponding ``start*`` checkpoint returns something truthy."""
    self.cbs("startBatch")
    try:
        self.cbs("startPass", "inPass", "endPass")
        self.cbs("startLoss", "inLoss", "endLoss")
        if not self.cbs("startBackward"): self.lossG.backward()
        if not self.cbs("startStep"): self.opt.step()
        if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True)
    except k1lib.CancelBatchException as ex:
        self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except (k1lib.CancelEpochException, k1lib.CancelRunException):
        # makes sure cancelBatch and endBatch gets called, for potential
        # cleanups, then reraise the exception
        self.cbs("cancelBatch", "endBatch"); raise
    self.cbs("endBatch")
class DI: # data interceptor, just to record data loading times
    def __init__(self, l:Learner, data): self.l = l; self.data = data
    def __len__(self): return len(self.data)
    def __iter__(self):
        """Yields every element of the wrapped data, adding the time spent
        fetching each one to ``l.cbs.timings.loadData``."""
        data = iter(self.data); timings = self.l.cbs.timings
        while True:
            beginTime = _time()
            # only the fetch itself is guarded, so a StopIteration coming from
            # anywhere else is not silently treated as "end of data"
            try: d = next(data)
            except StopIteration: return
            timings.loadData += _time() - beginTime; yield d
@k1lib.patch(Learner)
def _run1Epoch(self):
    """Runs all training batches, then (unless the ``startValidBatches``
    checkpoint returns something truthy) all validation batches. ``l.batch``
    keeps counting up through the validation batches."""
    self.cbs("startEpoch")
    try:
        train, valid = self.data; train = DI(self, train); valid = DI(self, valid)
        # the data objects may not support len(); then l.batches is filled in after the epoch
        try: self.batches = len(train) + len(valid)
        except Exception: pass
        self.model.train()
        # local count of batches run so far, so that an empty training or
        # validation set does not depend on a stale (or missing) `l.batch`
        seen = 0
        for self.batch, (self.xb, self.yb, *self.metab) in enumerate(train):
            self._run1Batch(); seen = self.batch + 1
        trainLen = seen
        if not self.cbs("startValidBatches"):
            self.model.eval()
            for self.batch, (self.xb, self.yb, *self.metab) in enumerate(valid):
                self.batch += trainLen; self._run1Batch(); seen = self.batch + 1
        if self.batches is None: self.batches = seen
    except k1lib.CancelEpochException as ex:
        self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
    except k1lib.CancelRunException:
        # same idea as in _run1Batch: let callbacks clean up, then propagate
        self.cbs("cancelEpoch", "endEpoch"); raise
    self.cbs("endEpoch")
@k1lib.patch(Learner)
def run(self, epochs:int, batches:int=None):
    """Main run function.

    :param epochs: number of epochs to run. 1 epoch is the length of the dataset
    :param batches: if set, then cancels the epoch after reaching the specified batch
    :return: the learner itself, or None if the user declined to continue
        when prompted about outstanding warnings"""
    warnings = self._warnings
    if warnings != "":
        if not input(f"You still have these warnings:\n\n{warnings}\nDo you want to continue? (y/n) ").lower().startswith("y"):
            print("Run ended"); return
    self.epochs = int(epochs); self.batches = None
    self.css = self.css # update module selector
    with self.cbs.context():
        if batches is not None: self.cbs.add(Cbs.BatchLimit(int(batches)))
        self.cbs("startRun")
        try:
            for self.epoch in range(self.epochs): self._run1Epoch()
        except k1lib.CancelRunException as ex:
            self.cbs("cancelRun"); print(f"Run cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "")
        self.cbs("endRun"); return self
@k1lib.patch(Learner)
def __call__(self, xb, yb=None):
    """Executes just a small batch. Convenience method to query how the network is doing.

    :param xb: x batch
    :param yb: y batch. If specified, return (y, loss), else return y alone"""
    oldData = self.data
    # `yb or ...` would call bool() on the tensor: that raises for tensors with
    # more than 1 element, and throws away a legitimate zero-valued target
    self.data = [[(xb, yb if yb is not None else torch.tensor(0))], []]
    try:
        with self.cbs.suspendEval(), self.cbs.context():
            ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException)
            # cut the batch short: before the loss if there is no target, else before backward
            self.cbs.add(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex))
            self.run(1, 1)
    finally: self.data = oldData # put the real data back even if the run blew up
    return self.y if yb is None else (self.y, self.loss)
@k1lib.patch(Learner)
def evaluate(self):
    """Function to visualize quickly how the network is doing. Undefined by default,
    just placed here as a convention, so you have to do something like this::

        l = k1lib.Learner()
        def evaluate(self):
            xbs, ybs, ys = self.Recorder.record(1, 3)
            plt.plot(torch.vstack(xbs), torch.vstack(ys))
        l.evaluate = partial(evaluate, l)
    """
    raise NotImplementedError("You have to define evaluate() by yourself")
# provides the cli tools (transpose, randomize, batched, ...) that `sample` pipes together
from k1lib.cli import *
@k1lib.patch(Learner, static=True)
def sample() -> Learner:
    """Creates an example learner, just for simple testing stuff anywhere. The
    network tries to learn the function y=x. Only bare minimum callbacks are included."""
    l = Learner(); x = torch.linspace(-5, 5, 1000)
    l.data = [x, x] | transpose() | randomize(None) | splitW() | (repeatFrom() | randomize() | batched(32) | (transpose() | toTensor()).all()).all() | stagger.tv(300) | toList()
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = k1lib.knn.LinBlock(1, 3)
            self.lin2 = nn.Linear(3, 1)
        def forward(self, x):
            return ((x[:, None] + 2) | self.lin1 | self.lin2).squeeze()
    l.model = Model(); l.cbs = k1lib.Callbacks().add(Cbs.CoreNormal()).add(Cbs.Loss()).add(Cbs.ProgressBar())
    l.lossF = lambda y, yb: ((y - yb) ** 2).sum()
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l