# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""For not very complicated accuracies functions"""
from ..callbacks import Callback, Callbacks, Cbs
from typing import Callable, Tuple
import k1lib
# torch is optional. Without it, fall back to a k1lib auto-declaring object so that
# the ``torch.Tensor`` references in the type aliases below still evaluate
# (presumably withAutoDeclare fabricates placeholder attributes on access — confirm).
try: import torch; hasTorch = True
except Exception: # any failure importing torch (not installed, broken native libs, ...) means "no torch"; don't swallow KeyboardInterrupt/SystemExit
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["AccF"]
# accF(predictions, yb) -> per-element accuracies tensor. AccF.endLoss calls it with two
# positional tensors and then calls .detach()/.float() on the result, so it is not
# a single-tuple -> float function.
AccFSig = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
# predF(y) -> predictions tensor
PredFSig = Callable[[torch.Tensor], torch.Tensor]
@k1lib.patch(Cbs) # AccF
class AccF(Callback): # AccF
    " " # AccF
    def __init__(self, predF:PredFSig=None, accF:AccFSig=None, integrations:bool=True, variable:str="accuracy", hookToLearner:bool=True): # AccF
        r"""Generic accuracy function.

        The built-in default prediction/accuracy functions are fine if you're not doing
        anything too dramatic/different. Expected variables in :class:`~k1lib.Learner`:

        - y: :class:`~torch.Tensor` of shape (\*N, C)
        - yb: :class:`~torch.Tensor` of shape (\*N,)

        Deposits variables into :class:`~k1lib.Learner`:

        - preds: detached, batched tensor output of ``predF`` (only if ``hookToLearner``)
        - accuracies: detached, batched tensor output of ``accF`` (only if ``hookToLearner``)
        - accuracy: single float, mean of ``accuracies``. Always deposited, under the
          attribute name given by ``variable``

        Where:

        - N is the batch size. Can be multidimensional, but has to agree between ``y`` and ``yb``
        - C is the number of categories

        :param predF: takes in ``y``, returns predictions (tensor with int elements indicating
            the categories). Defaults to ``y.argmax(-1)``
        :param accF: takes in ``(predictions, yb)``, returns accuracies (tensor with 0 or 1
            elements). Defaults to elementwise equality converted to ints
        :param integrations: whether to integrate :class:`~k1lib.callbacks.confusionMatrix.ConfusionMatrix` or not.
        :param variable: name of the Learner attribute that receives the mean accuracy
        :param hookToLearner: whether to also deposit ``preds`` and ``accuracies`` into the Learner"""
        super().__init__(); self.order = 10 # AccF
        self.integrations = integrations; self.hookToLearner = hookToLearner # AccF
        self.variable = variable # AccF
        # attached() only assigns conMatCb when integrations is on, but detach() always
        # reads it, so give it a defined starting value.
        self.conMatCb = None; self.ownsConMat = False # AccF
        self.predF = predF or (lambda y: y.argmax(-1)) # AccF
        self.accF = accF or (lambda p, yb: (p == yb)+0) # AccF
    def attached(self): # AccF
        """If ``integrations`` is on, hooks up a ConfusionMatrix callback. It reuses the one
        already in ``self.cbs`` if present. Otherwise it creates and adds one, and owns it,
        so :meth:`detach` removes it again."""
        if not self.integrations: return # AccF
        if "ConfusionMatrix" not in self.cbs: # AccF
            self.conMatCb = Cbs.ConfusionMatrix() # AccF
            self.cbs.add(self.conMatCb); self.ownsConMat = True # AccF
        else: self.conMatCb = self.cbs.ConfusionMatrix # AccF
    def endLoss(self): # AccF
        """Computes predictions and accuracies for the current batch from ``l.y`` and
        ``l.yb``, and deposits them into the Learner."""
        preds = self.predF(self.l.y); accs = self.accF(preds, self.l.yb) # AccF
        if self.hookToLearner: # AccF
            self.l.preds = preds.detach() # AccF
            self.l.accuracies = accs.detach() # AccF
        # the mean accuracy is deposited regardless of hookToLearner
        self.l.__dict__[self.variable] = accs.float().mean().item() # AccF
    def detach(self): # AccF
        """Detaches this callback, plus the ConfusionMatrix if this callback created it."""
        super().detach() # AccF
        if self.conMatCb is not None: # AccF
            if self.ownsConMat: self.conMatCb.detach() # AccF
            # forget ownership too, so a later attach that reuses someone else's
            # ConfusionMatrix won't detach it on the next detach()
            self.conMatCb = None; self.ownsConMat = False # AccF