# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""This is for core callbacks, that defines how everything is going to go"""
from .callbacks import Callback, Callbacks, Cbs
import k1lib; from typing import List, Tuple, Dict, Iterator, Union, Any, Callable
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["CoreNormal", "CoreRNN"]
@k1lib.patch(Cbs)
class CoreNormal(Callback): # CoreNormal
    """Just a normal, typical feed forward pass.
Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``:

- y: attached result tensor after passing through model""" # CoreNormal
    def inPass(self): # CoreNormal
        # Plain forward pass: feed the current batch through the model and
        # deposit the output on the Learner for downstream callbacks (loss, etc.)
        self.l.y = self.l.model(self.l.xb) # CoreNormal
@k1lib.patch(Cbs) # CoreRNN
class CoreRNN(Callback): # CoreRNN
    """RNN forward pass.
Expected variables from :attr:`k1lib.Learner.model`:

- initHidden: function takes in batch size, returns init hidden tensor

Deposits variables into :class:`~k1lib.Learner` at checkpoint ``inPass``, more
specifically ``rnnPass``:

- y: attached result tensor after pass (``inPass``), after character pass (``rnnPass``)
""" # CoreRNN
    def startBatch(self): # CoreRNN
        # NOTE(review): assumes xb is laid out (seq, batch, ...) so dim -2 is the
        # batch size handed to initHidden — confirm against the model's contract
        self.hx = self.l.model.initHidden(self.l.xb.shape[-2]) # CoreRNN
    def inPass(self): # CoreRNN
        # Hidden state may have been created on CPU; move it to wherever the batch lives
        self.hx = self.hx.to(self.l.xb.device) # CoreRNN
        # Step through the sequence one timestep at a time, threading the hidden state
        for item in self.l.xb: # CoreRNN
            self.l.y, self.hx = self.l.model(item, self.hx) # CoreRNN
            # Fire the per-timestep checkpoint (docstring: "after character pass") so
            # other callbacks can react to each step's output
            self.cbs("rnnPass") # CoreRNN