Source code for k1lib.selector

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
This module is for selecting a subnetwork using CSS so that you can do special
things to them. Checkout the tutorial section for a walkthrough. This is exposed
automatically with::

   from k1lib.imports import *
   selector.select # exposed
"""
import k1lib, re
from k1lib import cli
from typing import List, Tuple, Dict, Union, Any, Iterator, Callable
from contextlib import contextmanager; from functools import partial
try: import torch; from torch import nn; hasTorch = True
except Exception: # torch missing or failing to import; KeyboardInterrupt/SystemExit still propagate
    # placeholders built with k1lib's auto-declaring Object, presumably so that references such as
    # nn.Module / torch.Tensor later in this module still resolve without torch — confirm
    torch = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {}))
    nn = k1lib.Object().withAutoDeclare(lambda: type("RandomClass", (object, ), {})); hasTorch = False
__all__ = ["ModuleSelector", "preprocess", "select"]
def preprocess(selectors:str, defaultProp="*") -> List[str]:
    r"""Removes all quirky features allowed by the css language, and outputs nice
    lines. Example::

        # returns ["a:f", "a:g,h", "b:g,h", "t:*"]
        selector.preprocess("a:f; a, b: g,h; t")

    Blank statements (including whitespace-only ones) and empty comma-separated
    segments are dropped.

    :param selectors: single css selector string. Statements separated by "\\n" or ";"
    :param defaultProp: default property, if statement doesn't have one
    :returns: one "selector:props" string per comma-expanded selector"""
    # statements end at either a newline or a semicolon
    lines = [e for l in selectors.split("\n") for e in l.split(";")]
    cleaned = []
    for line in lines:
        # collapse whitespace runs, trim both ends, then drop spaces around the ">", ":" and "," delimiters
        line = re.sub(r"(^\s+)|(\s+$)", "", re.sub(r"\s\s+", " ", line))
        line = line.replace(" >", ">").replace("> ", ">").replace(" :", ":").replace(": ", ":").replace(" ,", ",").replace(", ", ",")
        # filter only after stripping, otherwise a whitespace-only statement would become a target-less ":*" selector
        if line: cleaned.append(line)
    # adding defaultProp to all selectors with no props specified
    selectors = [selector if ":" in selector else f"{selector}:{defaultProp}" for selector in cleaned]
    # expanding comma-delimited selectors, skipping empty segments such as in "a,,b"
    return [f"{segment}:{selector.split(':')[1]}" for selector in selectors for segment in selector.split(":")[0].split(",") if segment]
def _getParts(s:str):
    # selector part of "sel:props", split on ">" (child) and " " (descendant); empty pieces are discarded
    head = s.split(":")[0]
    parts = []
    for chunk in head.split(">"):
        if not chunk: continue
        parts.extend(piece for piece in chunk.split(" ") if piece)
    return parts
def _getProps(s:str):
    # comma-separated props after the first ":", empty entries discarded
    return list(filter(None, s.split(":")[1].split(",")))
_idxAuto = k1lib.AutoIncrement() # source of ModuleSelector.idx values
class ModuleSelector: # empty methods so that Sphinx generates the docs in order
    props:List[str]
    """Properties of this :class:`ModuleSelector`"""
    idx:int
    """Unique id of this :class:`ModuleSelector` in the entire script. May be useful for module recognition"""
    nn:"torch.nn.Module"
    """The associated :class:`torch.nn.Module` of this :class:`ModuleSelector`"""
    def __init__(self, parent:"ModuleSelector", name:str, nn:"torch.nn.Module"):
        self.parent = parent; self.name = name; self.nn = nn
        self._children:Dict[str, "ModuleSelector"] = {}
        self.props:List[str] = []; self.depth:int = 0
        self.directSelectors:List[str] = []
        self.indirectSelectors:List[str] = []
        # goes through the displayF setter, which walks the tree via apply(); _children must exist first
        self.displayF:Callable[["ModuleSelector"], str] = lambda mS: ', '.join(mS.props)
        self.idx = _idxAuto()
    def deepestDepth(self): pass
    def highlight(self, prop:str):
        """Highlights the specified prop when displaying the object. Returns self."""
        self.displayF = lambda mS: (k1lib.fmt.txt.red if prop in mS else k1lib.fmt.txt.identity)(', '.join(mS.props))
        return self
    def __call__(self, *args, **kwargs):
        """Calls the internal :class:`torch.nn.Module`"""
        return self.nn(*args, **kwargs)
    def __contains__(self): pass
    def named_children(self): pass
    def children(self): pass
    def named_modules(self): pass
    def modules(self): pass
    def directParams(self): pass
    def parse(self): pass
    def apply(self): pass
    def clearProps(self): pass
    @property
    def displayF(self):
        """Function to display each ModuleSelector's lines. Default is just::

            lambda mS: ", ".join(mS.props)
        """
        return self._displayF
    @displayF.setter
    def displayF(self, f):
        # setting it on one node sets it on that node and every descendant
        def applyF(self): self._displayF = f
        self.apply(applyF)
    def __getattr__(self, attr):
        # only reached when normal lookup fails: try named children, then parameters owned directly by the module
        if attr.startswith("_"): raise AttributeError(attr)
        if attr in self._children: return self._children[attr]
        try: return self.directParams[attr]
        except KeyError: raise AttributeError(attr) from None # hasattr/getattr(default) rely on AttributeError
    def __getitem__(self, idx):
        """Index-style access to children/params, e.g. ``mS[1][0]``. Raises KeyError when nothing matches."""
        try: return getattr(self, str(idx))
        except AttributeError: raise KeyError(idx) from None
    @staticmethod
    def sample() -> "ModuleSelector":
        """Create a new example :class:`ModuleSelector` that has a bit of hierarchy to them, with no css."""
        return nn.Sequential(nn.Linear(3, 4), nn.Sequential(nn.Conv2d(3, 8, 3, 2), nn.ReLU(), nn.Linear(5, 6)), nn.Linear(7, 8)).select("")
    def hookF(self): pass
    def hookFp(self): pass
    def hookB(self): pass
    def freeze(self): pass
    def unfreeze(self): pass
@k1lib.patch(nn.Module)
def select(model:"torch.nn.Module", css:str="*") -> "k1lib.selector.ModuleSelector":
    """Creates a new ModuleSelector, in sync with a model. Example::

        mS = selector.select(nn.Linear(3, 4), "#root:propA")

    Or, you can do it the more direct way::

        mS = nn.Linear(3, 4).select("#root:propA")

    :param model: the :class:`torch.nn.Module` object to select from
    :param css: the css selectors
    :returns: the root :class:`ModuleSelector`, named "root" """
    rootSelector = ModuleSelector(None, "root", model)
    parsedCss = preprocess(css)
    rootSelector.parse(parsedCss)
    return rootSelector
@k1lib.patch(ModuleSelector, name="apply")
def _apply(self, f:Callable[[ModuleSelector], None]):
    """Applies a function to self and all child :class:`ModuleSelector`"""
    f(self)
    for child in self._children.values(): child.apply(f)

@k1lib.patch(ModuleSelector, name="parse")
def _parse(self, selectors:Union[List[str], str]) -> ModuleSelector:
    """Parses extra selectors. Clears all old selectors, but retain the props.
    Returns self. Example::

        mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
        # returns True
        "propA" in mS[1][0]

    :param selectors: can be the preprocessed list, or the unprocessed css string"""
    if isinstance(selectors, str): selectors = preprocess(selectors)
    self.directSelectors = []; self.indirectSelectors = []
    ogSelectors = selectors # every child re-tries the original selectors itself
    if self.parent is not None:
        # also try the remainders the parent left for its direct children (">") and descendants (" ")
        selectors = list(selectors) + self.parent.indirectSelectors + self.parent.directSelectors
        # descendant remainders stay pending for every deeper level
        self.indirectSelectors += self.parent.indirectSelectors
        self.depth = self.parent.depth + 1
    for selector in selectors:
        parts = _getParts(selector)
        if not parts: continue # e.g. ":prop" names no module, so it can never match
        matches = parts[0] == self.nn.__class__.__name__ or parts[0] == "#" + self.name or parts[0] == "*"
        if len(parts) == 1:
            if matches: self.props += _getProps(selector)
        else:
            # whichever combinator appears first decides whether the remainder targets direct children or any descendant
            a = selector.find(">"); a = a if a > 0 else float("inf")
            b = selector.find(" "); b = b if b > 0 else float("inf")
            if matches:
                if a < b: self.directSelectors.append(selector[a+1:])
                else: self.indirectSelectors.append(selector[b+1:])
    for name, mod in self.nn.named_children():
        if name not in self._children:
            self._children[name] = ModuleSelector(self, name, mod)
        self._children[name].parse(ogSelectors)
    # dedupe while keeping first-seen order, so props display deterministically
    self.props = list(dict.fromkeys(self.props)); return self

@k1lib.patch(ModuleSelector)
def __contains__(self, prop:str=None) -> bool:
    """Whether this :class:`ModuleSelector` has a specific prop. Example::

        # returns True
        "b" in nn.Linear(3, 4).select("*:b")
        # returns False
        "h" in nn.Linear(3, 4).select("*:b")
        # returns True, "*" here means the ModuleSelector has any properties at all
        "*" in nn.Linear(3, 4).select("*:b")"""
    if "*" in self.props: return True # a "*" prop matches any query
    if prop in self.props: return True
    if prop == "*" and len(self.props) > 0: return True
    return False

@k1lib.patch(ModuleSelector)
def named_children(self, prop:str=None) -> Iterator[Tuple[str, ModuleSelector]]:
    """Get all named direct childs.

    :param prop: Filter property. See also: :meth:`__contains__`"""
    if prop is None: return self._children.items()
    return ((k, v) for k, v in self._children.items() if prop in v)

@k1lib.patch(ModuleSelector)
def children(self, prop:str=None) -> Iterator[ModuleSelector]:
    """Get all direct childs.

    :param prop: Filter property. See also: :meth:`__contains__`"""
    return (x for _, x in self.named_children(prop))

@k1lib.patch(ModuleSelector, "directParams")
@property
def directParams(self) -> Dict[str, nn.Parameter]:
    """Dict params directly under this module"""
    # parameters of submodules are named "child.param", so a dot means it is not direct
    return {name: param for name, param in self.nn.named_parameters() if "." not in name}

@k1lib.patch(ModuleSelector)
def named_modules(self, prop:str=None) -> Iterator[Tuple[str, ModuleSelector]]:
    """Get all named child recursively, starting with self. Example::

        modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
        # return 3
        len(modules)
        # return tuple ('0', <ModuleSelector of Linear>)
        modules[1]

    :param prop: Filter property. See also: :meth:`__contains__`"""
    if prop is not None:
        yield from ((name, m) for name, m in self.named_modules() if prop in m)
        return
    yield self.name, self
    for child in self._children.values(): yield from child.named_modules()

@k1lib.patch(ModuleSelector)
def modules(self, prop:str=None) -> Iterator[ModuleSelector]:
    """Get all child recursively.

    :param prop: Filter property. See also: :meth:`__contains__`"""
    for _, x in self.named_modules(prop): yield x

@k1lib.patch(ModuleSelector)
def clearProps(self) -> "ModuleSelector":
    """Clears all existing props of this and all descendants :class:`ModuleSelector`.
    Returns self. Example::

        # returns False
        "b" in nn.Linear(3, 4).select("*:b").clearProps()"""
    def applyF(self): self.props = []
    self.apply(applyF); return self

@k1lib.patch(ModuleSelector, name="deepestDepth")
@property
def deepestDepth(self):
    """Deepest depth of the tree. If self doesn't have any child, then depth is 0"""
    if len(self._children) == 0: return 0
    return 1 + max(child.deepestDepth for child in self._children.values())

@k1lib.patch(ModuleSelector)
def __repr__(self, intro:bool=True, header:Union[str, Tuple[str, str]]="", footer="", tabs:int=None):
    """
    :param intro: whether to include a nice header and footer info
    :param header:
        str: include a header that starts where `displayF` will start
        Tuple[str, str]: first one in tree, second one in displayF section
    :param footer: same thing with header, but at the end
    :param tabs: number of tabs at the beginning. Best to leave this empty
    """
    if tabs is None: tabs = 5 + self.deepestDepth
    answer = "ModuleSelector:\n" if intro else ""
    if header:
        h1, h2 = ("", header) if isinstance(header, str) else header
        answer += h1.ljust(tabs*4, " ") + h2 + "\n"
    answer += f"{self.name}: {self.nn.__class__.__name__}".ljust(tabs*4, " ")
    answer += self.displayF(self) + ("\n" if len(self._children) > 0 else "")
    # children render one tab narrower and are indented by one tab, so their displayF columns line up with ours
    answer += self._children.values() | cli.apply(lambda child: child.__repr__(tabs=tabs-1, intro=False).split("\n")) | cli.joinStreams() | cli.tab() | cli.join("\n")
    if footer:
        f1, f2 = ("", footer) if isinstance(footer, str) else footer
        answer += "\n" + f1.ljust(tabs*4, " ") + f2
    if intro: answer += f"""\n\nCan...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module"""
    return answer

def _strTensor(t): return "None" if t is None else f"{t.shape}"
def strTensorTuple(ts):
    """Describes the shapes of a sequence of tensors (``None`` entries allowed), for the default hook printers."""
    if len(ts) == 0: return "tensors (0 total)"
    if len(ts) > 1:
        shapes = "\n".join(f"- {_strTensor(t)}" for t in ts)
        return f"tensors ({len(ts)} total) shapes:\n{shapes}"
    return f"tensor shape: {_strTensor(ts[0])}"

def _hookTitle(kind:str, mS) -> str:
    """Header line used by the default hook printers."""
    return f"{kind} hook {mS.name} ({mS.nn.__class__.__name__}):\n"

@contextmanager
def _hookContext(self, prop:str, register):
    """Calls ``register(mS)`` for every :class:`ModuleSelector` matching ``prop``, and
    removes every returned hook handle when the context exits, even if registration
    or the body raised."""
    handles = []
    try:
        for mS in self.modules(prop): handles.append(register(mS))
        yield
    finally:
        for h in handles: h.remove()

@k1lib.patch(ModuleSelector)
@contextmanager
def hookF(self, f:Callable[[ModuleSelector, Tuple[torch.Tensor], torch.Tensor], None]=None, prop:str="*"):
    """Context manager for applying forward hooks. Example::

        def f(mS, i, o):
            print(i, o)

        m = nn.Linear(3, 4)
        with m.select().hookF(f):
            m(torch.randn(2, 3))

    :param f: hook callback, called as ``f(mS, inputs, output)`` where ``mS`` is the
        :class:`ModuleSelector` of the module that fired. If not specified, then it
        will print out input and output tensor shapes.
    :param prop: filter property of module to hook onto"""
    if f is None: f = lambda mS, i, o: print(_hookTitle("Forward", mS) + ([f"Input {strTensorTuple(i)}", f"Output tensor shape: {_strTensor(o)}"] | cli.tab() | cli.join("\n")))
    # the inner lambda closes over the outer lambda's own mS, so each hook reports its own module
    with _hookContext(self, prop, lambda mS: mS.nn.register_forward_hook(lambda _nn, i, o: f(mS, i, o))): yield

@k1lib.patch(ModuleSelector)
@contextmanager
def hookFp(self, f=None, prop:str="*"):
    """Context manager for applying forward pre hooks. Example::

        def f(mS, i):
            print(i)

        m = nn.Linear(3, 4)
        with m.select().hookFp(f):
            m(torch.randn(2, 3))

    :param f: hook callback, called as ``f(mS, inputs)`` where ``mS`` is the
        :class:`ModuleSelector` of the module that fired. If not specified, then it
        will print out input tensor shapes.
    :param prop: filter property of module to hook onto"""
    if f is None: f = lambda mS, i: print(_hookTitle("Forward pre", mS) + ([f"Input {strTensorTuple(i)}"] | cli.tab() | cli.join("\n")))
    with _hookContext(self, prop, lambda mS: mS.nn.register_forward_pre_hook(lambda _nn, i: f(mS, i))): yield

@k1lib.patch(ModuleSelector)
@contextmanager
def hookB(self, f=None, prop:str="*"):
    """Context manager for applying backward hooks. Example::

        def f(mS, i, o):
            print(i, o)

        m = nn.Linear(3, 4)
        with m.select().hookB(f):
            m(torch.randn(2, 3)).sum().backward()

    :param f: hook callback, called as ``f(mS, gradInputs, gradOutputs)`` where ``mS``
        is the :class:`ModuleSelector` of the module that fired. If not specified,
        then it will print out grad input and output tensor shapes.
    :param prop: filter property of module to hook onto"""
    if f is None: f = lambda mS, i, o: print(_hookTitle("Backward", mS) + ([f"Input {strTensorTuple(i)}", f"Output {strTensorTuple(o)}"] | cli.tab() | cli.join("\n")))
    with _hookContext(self, prop, lambda mS: mS.nn.register_full_backward_hook(lambda _nn, i, o: f(mS, i, o))): yield

from contextlib import ExitStack

@contextmanager
def _intercept(self, value:bool): # value is currently unused
    """Yields ``data``: one ``[forwardRecords, backwardRecords]`` pair per selected
    module, filled with detached inputs/outputs while the context is active."""
    handles = []
    detach = lambda x: x.detach() if x is not None else None
    try:
        data = []
        for m in self.modules("*"):
            fwd = []; bwd = []; data.append([fwd, bwd])
            # bind this module's lists as defaults; a plain closure would late-bind and
            # route every module's records into the last module's lists
            handles.append(m.nn.register_forward_hook(lambda _nn, i, o, fwd=fwd: fwd.append([[detach(e) for e in i], detach(o)])))
            handles.append(m.nn.register_full_backward_hook(lambda _nn, i, o, bwd=bwd: bwd.append([[detach(e) for e in i], [detach(e) for e in o]])))
        yield data
    finally:
        for h in handles: h.remove()

@k1lib.patch(ModuleSelector)
def intercept(self):
    """Returns a context manager that intercepts forward and backward signals to
    parts of the network. Example::

        l = k1lib.Learner.sample()
        with l.model.select("#lin1").intercept() as d:
            l.run(2)
        # returns (1, 2, 600, 2, 1, 32, 1), or (#selected modules, [forward, backward], #steps, [input, output], actual data)
        d | shape()

    See also: :meth:`hookF`, :meth:`hookFp`, :meth:`hookB`"""
    return _intercept(self, False)

@contextmanager
def _freeze(self, value:bool, prop:str):
    """Calls ``requires_grad_(value)`` on every module matching ``prop``, after entering
    its ``gradContext()`` on an ExitStack that is exited when this context ends."""
    with ExitStack() as stack:
        for m in self.modules(prop):
            # assumes gradContext() (patched onto nn.Module elsewhere in k1lib) restores the flags on exit — confirm
            stack.enter_context(m.nn.gradContext())
            m.nn.requires_grad_(value)
        yield

@k1lib.patch(ModuleSelector)
def freeze(self, prop:str="*"):
    """Returns a context manager that freezes (set requires_grad to False) parts of
    the network. Example::

        l = k1lib.Learner.sample()
        w = l.model.lin1.lin.weight.clone() # weights before
        with l.model.select("#lin1").freeze():
            l.run(1)
        # returns True
        (l.model.lin1.lin.weight == w).all()"""
    return _freeze(self, False, prop)

@k1lib.patch(ModuleSelector)
def unfreeze(self, prop:str="*"):
    """Returns a context manager that unfreezes (set requires_grad to True) parts of
    the network. Example::

        l = k1lib.Learner.sample()
        w = l.model.lin1.lin.weight.clone() # weights before
        with l.model.select("#lin1").freeze():
            with l.model.select("#lin1 > #lin").unfreeze():
                l.run(1)
        # returns False
        (l.model.lin1.lin.weight == w).all()"""
    return _freeze(self, True, prop)

class CutOff(nn.Module):
    """Wraps ``net`` so that calling it runs the whole network but returns the
    output of its submodule ``m`` (captured with a forward hook)."""
    def __init__(self, net, m):
        super().__init__()
        self.net = net; self.m = m; self._lastOutput = None
        def f(m, i, o): self._lastOutput = o
        self.handle = self.m.register_forward_hook(f)
    def forward(self, *args, **kwargs):
        self._lastOutput = None
        self.net(*args, **kwargs)
        return self._lastOutput
    def __del__(self):
        # __init__ may have failed before the hook was registered
        handle = self.__dict__.get("handle")
        if handle is not None: handle.remove()

@k1lib.patch(ModuleSelector)
def cutOff(self) -> nn.Module:
    """Creates a new network that returns the selected layer's output. Example::

        xb = torch.randn(10, 2)
        m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
        m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
        # returns (10, 6)
        m(xb).shape
        # returns (10, 5)
        m0(xb).shape == torch.Size([10, 5])
        # returns (10, 4)
        m1(xb).shape == torch.Size([10, 4])"""
    # the first module carrying any prop is the cut point
    return CutOff(self.nn, self.modules("*") | cli.item() | cli.op().nn)