# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
import k1lib, dill, traceback
from k1lib.callbacks import Cbs
from typing import Union
from time import time as _time
try: import torch; import torch.nn as nn; hasTorch = True
except ImportError:
torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CancelRunException", "CancelEpochException", "CancelBatchException",
"Learner"]
class CancelRunException(Exception): # CancelRunException
    """Used in the core training loop to skip the run entirely""" # CancelRunException
    pass # CancelRunException
class CancelEpochException(Exception): # CancelEpochException
    """Used in the core training loop to skip to the next epoch""" # CancelEpochException
    pass # CancelEpochException
class CancelBatchException(Exception): # CancelBatchException
    """Used in the core training loop to skip to the next batch""" # CancelBatchException
    pass # CancelBatchException
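# A hedged sketch of how these exceptions are used: raise one from inside a
# callback (or from your own code in the loop), and _run1Batch/_run1Epoch/run
# below will catch it and skip the right amount of work. The callback and the
# 1e3 threshold here are made up for illustration:
def _exampleCancelUsage():
    class LossGuard(k1lib.Callback): # hypothetical callback
        def endLoss(self): # checkpoint name taken from _run1Batch below
            if self.l.loss > 1e3: raise CancelRunException("loss exploded")
    return LossGuard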
def _tab(text:Union[list, str], pad=" ") -> Union[list, str]: # _tab
    if isinstance(text, str): # an old function that's been replaced in the main lib, but still useful here # _tab
return "\n".join([pad + line for line in text.split("\n")]) # _tab
else: return [pad + line for line in text] # _tab
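# Quick behavior sketch for _tab (assumed from the code above):
#   _tab("a\nb")     -> pad + "a" + "\n" + pad + "b"
#   _tab(["a", "b"]) -> [pad + "a", pad + "b"]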
class Learner: # Learner
def __init__(self): # Learner
self._model = None; self._data = None; self._opt = None # Learner
self._cbs = None; self.fileName = None # Learner
self.css = "*"; self.exceptionRaised = None # slowly pops # Learner
self.cbs = k1lib.Callbacks().withBasics().withQOL().withAdvanced() # Learner
@property # Learner
def model(self): # Learner
"""Set this to change the model to run""" # Learner
return self._model # Learner
@model.setter # Learner
def model(self, model): self._model = model # Learner
@property # Learner
def data(self): # Learner
"""Set this to change the data (list of 2 dataloader) to run against.""" # Learner
return self._data # Learner
@data.setter # Learner
def data(self, data): self._data = data # Learner
@property # Learner
def opt(self): # Learner
"""Set this to change the optimizer. If you're making your own
optimizers, beware to follow the PyTorch's style guide as there are callbacks
that modifies optimizer internals while training like
:class:`k1lib.schedule.ParamScheduler`.""" # Learner
return self._opt # Learner
@opt.setter # Learner
def opt(self, opt): self._opt = opt # Learner
@property # Learner
def cbs(self): # Learner
"""The :class:`~k1lib.callbacks.callbacks.Callbacks` object. Initialized to
include all the common callbacks. You can set a new one if you want to.""" # Learner
return self._cbs # Learner
@cbs.setter # Learner
def cbs(self, cbs): cbs.l = self; self._cbs = cbs # Learner
@property # Learner
def css(self) -> str: # Learner
"""The css selector string. Set this to select other parts of the network.
After setting, you can access the selector like this: :code:`l.selector`
See also: :class:`~k1lib.selector.ModuleSelector`""" # Learner
return self._css # Learner
@css.setter # Learner
def css(self, css:str): # Learner
self._css = css # Learner
        if self.model is not None: self.selector = k1lib.selector.select(self.model, self.css) # Learner
@property # Learner
def lossF(self): # Learner
"""Set this to specify a loss function.""" # Learner
raise NotImplementedError("lossF actually doesn't really exist. Used to exist as a core part of Learner, but then has been converted to Cbs.LossF") # Learner
@lossF.setter # Learner
def lossF(self, lossF): # Learner
if hasattr(self.cbs, "LossF"): self.cbs.LossF.lossF = lossF # Learner
else: self.cbs.add(Cbs.LossF(lossF)) # Learner
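    # Hedged usage sketch for the setter above; these two lines are equivalent,
    # since the setter just wraps the callable in Cbs.LossF:
    #   l.lossF = lambda y, yb: ((y - yb) ** 2).mean()
    #   l.cbs.add(Cbs.LossF(lambda y, yb: ((y - yb) ** 2).mean()))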
def __getattr__(self, attr): # Learner
if attr == "cbs": raise AttributeError() # Learner
return getattr(self.cbs, attr) # Learner
def __getstate__(self): # Learner
answer = dict(self.__dict__); answer.pop("selector", None) # Learner
answer.pop("_data", None); return answer # Learner
def __setstate__(self, state): # Learner
self.__dict__.update(state) # Learner
self.__dict__["_data"] = None # Learner
self.css = self.css; self.cbs.l = self # Learner
    def evaluate(self): pass # supposed to be overridden, to provide functionality here # Learner
@property # Learner
def _warnings(self): # Learner
warnings = "Warning: no model yet. Set using `l.model = ...`\n" if self.model == None else "" # Learner
lossClasses = tuple([*k1lib.Callback.lossCls]) # Learner
lossFnCbs = [True for cb in self.cbs if isinstance(cb, lossClasses)] # Learner
warnings += "Warning: no loss function callback detected (or you set `lossF` already but then erased all callbacks)! Set using `l.lossF = ...` or `l.cbs.add(Cbs.LossF(...))`\n" if len(lossFnCbs) == 0 else "" # Learner
warnings += "Warning: no data yet. Set using `l.data = ...`\n" if self.data == None else "" # Learner
warnings += "Warning: no optimizer yet. Set using `l.opt = ...`\n" if self.opt == None else "" # Learner
if warnings != "": warnings += "\n\n" # Learner
return warnings # Learner
def __dir__(self): # Learner
answer = list(super().__dir__()) # Learner
answer.extend(self.cbs.cbsDict.keys()); return answer # Learner
def __repr__(self): # Learner
return f"""{self._warnings}l.model:\n{_tab(k1lib.limitLines(str(self.model)))}
l.opt:\n{_tab(k1lib.limitLines(str(self.opt)))}
l.cbs:\n{_tab(k1lib.limitLines(self.cbs.__repr__()))}
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, in this case "Loss"\n\n""" # Learner
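# Typical wiring, as hinted by __repr__ above; a sketch with placeholder
# model/optimizer (supply your own dataloaders):
def _exampleSetup():
    l = Learner()
    l.model = nn.Linear(1, 1)
    l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3)
    l.lossF = lambda y, yb: ((y - yb) ** 2).mean()
    # l.data = [trainDL, validDL] # hypothetical train/valid dataloaders
    return l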
@k1lib.patch(Learner) # Learner
def save(self, fileName:str=None): # save
"""Saves this :class:`Learner` to file. See also: :meth:`load`. Does not
save the ``data`` object, because that's potentially very big. Example::
l = k1.Learner()
# saves learner to "skip1_128bs.pth" and model to "skip1_128bs.model.pth"
l.save("skip1_128bs")
:param fileName: name to save file into""" # save
torch.save(self, f"{fileName}.pth", pickle_module=dill) # save
torch.save(self.model, f"{fileName}.model.pth", pickle_module=dill) # save
print(f"Saved to {fileName}") # save
@k1lib.patch(Learner, static=True) # save
def load(fileName:str=None): # load
"""Loads a :class:`Learner` from a file. See also: :meth:`save`.
Example::
# this will load up learner in file "skip1_128bs.pth"
l = k1.Learner.load("skip1_128bs")
:param fileName: if empty, then will prompt for file name""" # load
    f = fileName or input("Enter learner file name to load: ") # load
print(f"Loaded from {f}"); return torch.load(f"{f}.pth", pickle_module=dill) # load
@k1lib.patch(Learner) # load
def _run1Batch(self): # _run1Batch
self.cbs("startBatch") # _run1Batch
try: # _run1Batch
self.cbs("startPass", "inPass", "endPass") # _run1Batch
self.cbs("startLoss", "inLoss", "endLoss") # _run1Batch
if not self.cbs("startBackward"): self.lossG.backward() # _run1Batch
if not self.cbs("startStep"): self.opt.step() # _run1Batch
if not self.cbs("startZeroGrad"): self.opt.zero_grad(set_to_none=True) # _run1Batch
except k1lib.CancelBatchException as ex: # _run1Batch
self.cbs("cancelBatch"); print(f"Batch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "") # _run1Batch
except (k1lib.CancelEpochException, k1lib.CancelRunException) as ex: # _run1Batch
# makes sure cancelBatch and endBatch gets called, for potential # _run1Batch
# cleanups, then reraise the exception # _run1Batch
self.cbs("cancelBatch", "endBatch"); raise ex # _run1Batch
self.cbs("endBatch") # _run1Batch
class DI: # data interceptor, just to record data loading times # DI
def __init__(self, l:Learner, data): self.l = l; self.data = data # DI
def __len__(self): return len(self.data) # DI
def __iter__(self): # DI
try: # DI
data = iter(self.data); timings = self.l.cbs.timings # DI
while True: # DI
beginTime = _time(); d = next(data) # DI
timings.loadData += _time() - beginTime; yield d # DI
except StopIteration: pass # DI
@k1lib.patch(Learner) # DI
def _run1Epoch(self): # _run1Epoch
self.cbs("startEpoch") # _run1Epoch
try: # _run1Epoch
train, valid = self.data; train = DI(self, train); valid = DI(self, valid) # _run1Epoch
        try: self.batches = len(train) + len(valid) # _run1Epoch
        except TypeError: pass # dataloaders may not support len() # _run1Epoch
self.model.train() # _run1Epoch
for self.batch, (self.xb, self.yb, *self.metab) in enumerate(train): # _run1Epoch
self._run1Batch() # _run1Epoch
trainLen = self.batch + 1 # _run1Epoch
if not self.cbs("startValidBatches"): # _run1Epoch
            self.model.eval() # _run1Epoch
for self.batch, (self.xb, self.yb, *self.metab) in enumerate(valid): # _run1Epoch
self.batch += trainLen; self._run1Batch() # _run1Epoch
if self.batches is None: self.batches = self.batch + 1 # _run1Epoch
except k1lib.CancelEpochException as ex: # _run1Epoch
self.cbs("cancelEpoch"); print(f"Epoch cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "") # _run1Epoch
except k1lib.CancelRunException as ex: # _run1Epoch
self.cbs("cancelEpoch", "endEpoch"); raise ex # _run1Epoch
self.cbs("endEpoch") # _run1Epoch
@k1lib.patch(Learner) # _run1Epoch
def run(self, epochs:int, batches:int=None): # run
"""Main run function.
:param epochs: number of epochs to run. 1 epoch is the length of the dataset
:param batches: if set, then cancels the epoch after reaching the specified batch""" # run
if self._warnings != "": # run
if not input(f"""You still have these warnings:\n\n{self._warnings}
Do you want to continue? (y/n) """).lower().startswith("y"): # run
print("Run ended"); return # run
self.epochs = int(epochs); self.batches = None # run
self.css = self.css # update module selector # run
with self.cbs.context(): # run
if batches is not None: self.cbs.add(Cbs.BatchLimit(int(batches))) # run
self.cbs("startRun") # run
try: # run
for self.epoch in range(self.epochs): self._run1Epoch() # run
except k1lib.CancelRunException as ex: # run
self.cbs("cancelRun"); print(f"Run cancelled: {ex}.", end="\n" if k1lib.settings.cancelRun_newLine else "") # run
self.cbs("endRun"); return self # run
@k1lib.patch(Learner) # run
def __call__(self, xb, yb=None): # __call__
"""Executes just a small batch. Convenience method to query how the network is
doing.
:param xb: x batch
:param yb: y batch. If specified, return (y, loss), else return y alone
""" # __call__
    oldData = self.data; self.data = [[(xb, yb if yb is not None else torch.tensor(0))], []] # __call__
with self.cbs.suspendEval(), self.cbs.context(): # __call__
ex = lambda _: k1lib.raiseEx(k1lib.CancelBatchException) # __call__
self.cbs.add(k1lib.Callback().withCheckpoint("startLoss" if yb is None else "startBackward", ex)) # __call__
self.run(1, 1) # __call__
self.data = oldData; return self.y if yb is None else (self.y, self.loss) # __call__
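# Usage sketch for __call__ above (shapes depend on your model):
#   y = l(xb)            # forward pass only; cancels before the loss step
#   y, loss = l(xb, yb)  # forward + loss; cancels before backward/step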
@k1lib.patch(Learner) # __call__
def evaluate(self): # evaluate
"""Function to visualize quickly how the network is doing. Undefined by default,
just placed here as a convention, so you have to do something like this::
l = k1lib.Learner()
def evaluate(self):
xbs, ybs, ys = self.Recorder.record(1, 3)
plt.plot(torch.vstack(xbs), torch.vstack(ys))
l.evaluate = partial(evaluate(l))
""" # evaluate
raise NotImplementedError("You have to define evaluate() by yourself") # evaluate
from k1lib.cli import * # evaluate
@k1lib.patch(Learner, static=True) # evaluate
def sample() -> Learner: # sample
"""Creates an example learner, just for simple testing stuff anywhere. The
network tries to learn the function y=x. Only bare minimum callbacks are included.""" # sample
l = Learner(); x = torch.linspace(-5, 5, 1000) # sample
l.data = [x, x] | transpose() | randomize(None) | splitW() | (repeatFrom() | randomize() | batched(32) | (transpose() | toTensor()).all()).all() | stagger.tv(300) | toList() # sample
class Model(torch.nn.Module): # sample
def __init__(self): # sample
super().__init__() # sample
self.lin1 = k1lib.knn.LinBlock(1, 3) # sample
self.lin2 = nn.Linear(3, 1) # sample
def forward(self, x): # sample
return ((x[:, None] + 2) | self.lin1 | self.lin2).squeeze() # sample
l.model = Model(); l.cbs = k1lib.Callbacks().add(Cbs.CoreNormal()).add(Cbs.Loss()).add(Cbs.ProgressBar()) # sample
l.lossF = lambda y, yb: ((y - yb) ** 2).sum() # sample
l.opt = torch.optim.Adam(l.model.parameters(), lr=3e-3); return l # sample
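# Quick smoke-test sketch using the sample learner above:
#   l = Learner.sample()
#   l.run(1)                      # should drive the loss down quickly
#   l(torch.linspace(-5, 5, 32))  # query the trained network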