# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, warnings
from typing import List, Callable
try: import torch; hasTorch = True
except: torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["ConfusionMatrix"]
@k1lib.patch(Cbs)
class ConfusionMatrix(Callback):
    """Callback that accumulates a (target, prediction) confusion matrix."""
    categories:List[str]
    """String categories for displaying the matrix. You can set this
so that it displays what you want, in case this Callback is included
automatically."""
    matrix:torch.Tensor
    """The recorded confusion matrix."""

    def __init__(self, categories:List[str]=None, condF:Callable[["ConfusionMatrix"], bool]=lambda _: True):
        """Records what categories the network is confused the most. Expected
variables in :class:`~k1lib.Learner`:

- preds: long tensor with categories id of batch before checkpoint ``endLoss``.
  Auto-included in :class:`~k1lib.callbacks.lossFunctions.accuracy.AccF` and
  :class:`~k1lib.callbacks.lossFunctions.shorts.LossNLLCross`.

:param categories: optional list of category names
:param condF: takes in this callback and returns whether to record at this
    particular ``endLoss`` checkpoint."""
        super().__init__(); self.categories = categories
        # Start with one row/column per category, or 2 when no (or empty)
        # categories are given; `_adapt` grows the matrix on demand.
        self.n = len(categories or []) or 2; self.condF = condF
        self.matrix = torch.zeros(self.n, self.n)
        self.wipeOnAdd = False # flag to wipe matrix on adding new data points

    def _adapt(self, idxs):
        """Adapts the internal matrix so that it supports new categories.

Grows the matrix (keeping existing counts) if ``idxs`` contains an index
beyond the current size, moves the matrix to ``idxs.device``, then returns
``idxs`` unchanged."""
        m = idxs.max().item() + 1 # +1 because max index = len() - 1
        if m > self.n:
            matrix = torch.zeros(m, m)
            matrix[:self.n, :self.n] = self.matrix # carry over old counts
            self.matrix = matrix; self.n = len(self.matrix)
        self.matrix = self.matrix.to(idxs.device); return idxs

    def startEpoch(self):
        """Schedules the matrix to be reset at the next recorded batch."""
        self.wipeOnAdd = True

    def endLoss(self):
        """Adds the current batch's (``l.yb``, ``l.preds``) pairs to the matrix,
if ``condF`` allows it."""
        if not self.condF(self): return
        if self.wipeOnAdd:
            self.matrix = torch.zeros(self.n, self.n)
            self.wipeOnAdd = False
        yb = self._adapt(self.l.yb); preds = self._adapt(self.l.preds)
        # `self.matrix[yb, preds] += 1` increments each distinct
        # (target, prediction) pair only once per batch, no matter how many
        # samples share it. Accumulating index_put_ counts every sample.
        self.matrix.index_put_((yb, preds), self.matrix.new_ones(()), accumulate=True)

    @property
    def goodMatrix(self) -> torch.Tensor:
        """Clears all inf, nans and whatnot from the matrix, then returns it.

The matrix is shrunk from the bottom-right corner until it has no nan or
infinite values (emitting a warning if that removed anything), trimmed to
``len(categories)`` if categories are set, and finally each row is divided
by its maximum value."""
        n = self.n; m = self.matrix
        while m.hasNan() or m.hasInfs():
            n -= 1; m = m[:n, :n]
        if n != self.n: warnings.warn(f"Originally, the confusion matrix has {self.n} categories, now it has {n} only, after filtering, because there are some nans and infinite values.")
        if self.categories is not None:
            n = len(self.categories); m = m[:n, :n]
        # NOTE(review): a row whose maximum is 0 makes this division yield
        # nan for that row - confirm whether the plotting code relies on that.
        return m/m.max(dim=1).values[:,None]

    def plot(self):
        """Plots everything"""
        k1lib.viz.confusionMatrix(self.goodMatrix, self.categories or list(range(self.n)))

    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""