# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
This module allows you to make and combine a bunch of schedules, and set up the
optimizer so that it changes hyperparameter values based on the schedule. We highly
recommend checking out the `tutorials section <tutorials.html>`_ on this.
This is exposed automatically with::
from k1lib.imports import *
schedule.Fn # exposed
"""
import math, k1lib; import k1lib.cli as cli
plt = k1lib.dep.plt
import numpy as np
from itertools import accumulate
from k1lib.callbacks import Cbs, Callback
from typing import List, Callable, Union
__all__ = ["Fn", "linear", "smooth", "hump", "exp", "ParamScheduler"]
class Fn:
    def __init__(self, f:Callable[[float], float], param:str=None):
        """Creates a new schedule based on some custom function.

        Example::

            s = schedule.Fn(lambda x: x**2)
            s(0.2) # returns 0.04

            # you can also use this as a decorator
            @schedule.Fn
            def s(x):
                return x**2

        :param f: function whose domain is always [0, 1]; can also be a :class:`~k1lib.cli.modifier.op`
        :param param: (optional) Parameter to schedule (e.g "lr") if using :class:`ParamScheduler`"""
        if isinstance(f, cli.op): f.op_solidify() # cli ops get op_solidify() called before being stored
        self.f = f
        self.param = param
        self.progress = None # last progress handed to _startBatch; None until a batch has run
        # share of the timeline this schedule takes when combined with `+`
        # (see CombinedSchedule); rescaled by the * and / operators below
        self.domain = k1lib.Range(0, 1)
    def __call__(self, x:float):
        """Returns the schedule's value at progress ``x``."""
        return self.f(x)
    def _startBatch(self, paramGroup:dict, progress:float):
        # remember the progress, then write the scheduled value into the optimizer param group
        self.progress = progress
        paramGroup[self.param] = self(progress)
    @property
    def value(self):
        """Schedule value at the last recorded progress (``self.progress``)."""
        return self.f(self.progress)
    # The scaling operators below update ``self.domain`` and return ``self`` (no copy),
    # so `0.3*s` re-weights `s` itself.
    def __mul__(self, x): self.domain *= x; return self
    __rmul__ = __mul__
    def __truediv__(self, x): self.domain /= x; return self
    def __rtruediv__(self, x):
        # `x / s` scales the domain by 1/x through __mul__
        self * (1.0/x)
        return self
    def __radd__(self, v):
        # absorbs any int left operand (e.g. the 0 that sum() starts from), so sum([s1, s2]) works
        return self if isinstance(v, int) else NotImplemented
    def __add__(self, s:Union["Fn", str]) -> "Fn":
        """If given :class:`Fn`, then combines the 2 schedules together.
        If it's a string, then sets the current param to it."""
        if not isinstance(s, Fn):
            self.param = s
            return self
        return CombinedSchedule(self, s)
    def iter(self, n:int):
        """Returns an n-step iterator evenly divided in range [0, 1].

        Example::

            s = schedule.Fn(lambda x: x+2)
            list(s.iter(6)) # returns [2.0, 2.2, 2.4, 2.6, 2.8, 3.0]"""
        yield from map(self, np.linspace(0, 1, n))
    def modifyOutput(self, f:Callable[[float], float]) -> "Fn":
        """Returns a new :class:`Fn` that has its output modified.

        Example::

            s = Fn(lambda x: x+2)
            s2 = s.modifyOutput(lambda x: x**2) # s2 computes (x+2)**2; s is unchanged

        The new schedule keeps ``param`` but starts with a fresh [0, 1] ``domain``."""
        def modified(x): return f(self.f(x)) # reads self.f lazily, at call time
        return Fn(modified, self.param)
@k1lib.patch(Fn)
def __repr__(self):
    """Plots the schedule over its domain, marks both ends and the last recorded
    progress, then returns a short usage hint."""
    plt.figure(dpi=100); marker = dict(color="tab:green")
    xs = np.linspace(*self.domain, 1000)
    plt.plot(xs, [self.f(v) for v in xs])
    # mark the values at progress 0 and 1
    for end, label in ((0, "(0, {:.1e})"), (1, "(1, {:.1e})")):
        y = self(end); plt.plot(end, y, "o", **marker); plt.annotate(label.format(y), (end, y))
    progress = self.progress
    if progress is not None:
        # near either end (outside [0.1, 0.9]) the marker is faded and left unlabelled,
        # presumably so it doesn't clutter the endpoint annotations
        blur = progress not in k1lib.Range(0.1, 0.9)
        y = self(progress)
        plt.plot(progress, y, "o", **marker, alpha=(0.5 if blur else 1))
        if not blur: plt.annotate("({:.1e}, {:.1e})".format(progress, y), (progress, y))
    plt.show()
    return f"""'{self.param}' schedule. Can...
- s.progress: to get last recorded progress
- s.value: to get last recorded hyper parameter's value
- s(0.3): to get value of schedule at 30% progress"""
class CombinedSchedule(Fn):
    """Schedule that plays ``s1`` and then ``s2`` back to back over [0, 1].

    The switch point is derived from both children's ``domain`` sizes, so `0.3*a + 0.7*b`
    hands the first part of the timeline to ``a`` and the rest to ``b``. Each child sees
    its own progress rescaled to [0, 1]. The param comes from ``s1``, falling back to ``s2``."""
    def __init__(self, s1, s2):
        # fraction of overall progress at which control passes from s1 to s2
        split = s1.domain.stop / (s1.domain.delta + s2.domain.delta)
        firstRange = k1lib.Range(0, split); secondRange = k1lib.Range(split, 1)
        def f(x):
            return s1.f(firstRange.toUnit(x)) if x < split else s2.f(secondRange.toUnit(x))
        super().__init__(f, s1.param or s2.param)
def decorate(f:Callable[[float, float, float], float]) -> Callable[..., Fn]:
    """Decorator that turns ``f(low, high, x)`` into a schedule factory.

    The decorated name becomes ``(low, high, param=None) -> Fn``. The returned :class:`Fn`
    evaluates ``f(low, high, x)`` at progress ``x``. ``f``'s metadata is copied onto
    the factory via ``k1lib.wraps``.

    :param f: function of the two endpoints and the progress ``x``
    :return: the factory function (not an :class:`Fn` itself)"""
    def _f(low, high, param:str=None):
        return Fn(lambda x: f(low, high, x), param)
    return k1lib.wraps(f)(_f)
@decorate
def linear(low, high, x):
    """Goes from low to high at a constant rate."""
    span = high - low
    return low + x * span
@decorate
def smooth(low, high, x):
    """Smoothly goes from low to high, following half a cosine period (flat at both ends)."""
    span = high - low
    return low + span * (1 + math.cos(math.pi * (1-x))) / 2
def hump(low, high, param:str=None):
    """Smoothly rises up (30%), then down (70%).

    The rise starts at ``0.8*low + 0.2*high`` and reaches ``high``. The fall then eases
    from ``high`` back to ``low``.

    :param param: (optional) Parameter to schedule (e.g "lr")"""
    rise = smooth(0.8 * low + 0.2 * high, high)
    fall = smooth(high, low, param)
    return 0.3*rise + 0.7*fall
_en4 = math.e**-3 # value of exp(-x*4+1) at x=1 (despite the name, this is e^-3)
@decorate
def exp(low, high, x):
    """Rises/drops quickly, then rate of change gets smaller and smaller.

    The exponential is normalized so that x=0 maps to ``low`` and x=1 maps to ``high``
    (up to float rounding)."""
    decay = math.exp(-x*4+1) - _en4
    return decay / (math.e - _en4) * (low - high) + high
@k1lib.patch(Cbs)
class ParamScheduler(Callback):
    """Schedules a param in parts of the network.

    :param css: the selected parts of the network to schedule
    :param schedules: one or more :class:`Fn`, each with its ``param`` set (e.g. "lr").
        If several schedules target the same param, the last one wins."""
    def __init__(self, css:str, *schedules:Fn):
        super().__init__(); self.css = css
        for i, s in enumerate(schedules):
            if s.param is None: raise RuntimeError(f"Schedule {i} does not have associated parameter! Set with `s.param = 'lr'`.")
        self.schedules = {s.param:s for s in schedules}
        # startBatch reads self.l.progress, presumably maintained by the ProgressBar callback.
        # dependsOn must be a set of callback *names*; set("ProgressBar") would be a set of
        # single characters {'P', 'r', 'o', ...}.
        self.groupId = None; self.dependsOn = {"ProgressBar"}
        self.initialized = False; self.prop = None
    def endRun(self):
        ":meta private:"
        # allow the next run's startRun to rebuild the optimizer param groups
        self.initialized = False
    def __getstate__(self):
        # the selector is not pickled; startRun recreates it on the next run
        answer = dict(self.__dict__)
        if "selector" in answer: del answer["selector"]
        return answer
    def startBatch(self):
        """While training, writes every schedule's value at the current progress
        into this scheduler's optimizer param group (no-op if it owns no group)."""
        if self.l.model.training and self.groupId is not None:
            paramGroup = self.l.opt.param_groups[self.groupId]
            progress = self.l.progress
            for schedule in self.schedules.values():
                schedule._startBatch(paramGroup, progress)
    def __repr__(self):
        print(f"{self._reprHead}, css: \"{self.css}\", selector prop: \"{self.prop}\", schedules:")
        for schedule in self.schedules.values(): schedule.__repr__()
        return f"""Can...
- ps.schedules["lr"]: to get the schedule for a specific param
- ps.selector: to view the selected parameters
{self._reprCan}"""
def _displayMarker(groupParams:set):
    """Builds a selector ``apply`` callback that stars modules owning any param in ``groupParams``.

    A factory is used so each callback captures its own set. A lambda written directly in
    the loop would late-bind to the loop variable and every selector would end up
    showing the last group's params."""
    def applyF(mS):
        mS.displayF = lambda s: "*" if any(p in groupParams for p in s.directParams.values()) else ""
    return applyF
@k1lib.patch(ParamScheduler, name="startRun")
def _startRun(self):
    """Rebuilds the optimizer's param groups so each active ParamScheduler owns one group.

    The first non-initialized scheduler does the work for all of them and marks them all
    initialized. Schedulers whose css selects deeper modules claim parameters first, so a
    parameter selected by several schedulers goes to the deepest one. Every unclaimed
    parameter lands in a final "rest" group.

    :raises RuntimeError: if a scheduler's css selects no module at all"""
    if self.initialized: return
    allPss = [cb for cb in self.l.cbs if isinstance(cb, ParamScheduler)]
    # param groups are rebuilt from scratch below, so any previously assigned index is invalid
    for ps in allPss: ps.groupId = None
    pss = [ps for ps in allPss if not ps.suspended]
    for i, ps in enumerate(pss):
        # make sure only 1 startRun is ran across all ParamSchedulers
        ps.initialized = True; ps.prop = f"_ps_{i}"
        ps.selector = k1lib.selector.select(self.l.model, ps.css)
        firstModule = next(ps.selector.modules(ps.prop), None)
        if firstModule is None:
            raise RuntimeError(f"ParamScheduler's css \"{ps.css}\" does not select any module in the model!")
        ps._depth = firstModule.depth
    # deeper selections get to claim their parameters first
    pss = sorted(pss, key=lambda ps: -ps._depth)
    opt = self.l.opt; opt.param_groups = []
    # insertion-ordered "set" of unclaimed params: torch needs ordered param lists for
    # reproducible optimizer state, and plain set order depends on memory addresses
    remaining = dict.fromkeys(self.l.selector.nn.parameters())
    for ps in pss:
        params = []
        for m in ps.selector.modules(ps.prop):
            for p in m.nn.parameters():
                if p in remaining: # also dedups params reachable through several modules
                    params.append(p); del remaining[p]
        if params:
            # so that we have a way to reference the group later on
            ps.groupId = len(opt.param_groups)
            opt.add_param_group({"prop": ps.prop, "css": ps.css, "params": params, **opt.defaults})
    opt.add_param_group({"prop": "rest", "css": "*", "params": list(remaining), **opt.defaults})
    for ps in pss:
        if ps.groupId is None: continue
        ps.selector.apply(_displayMarker(set(opt.param_groups[ps.groupId]["params"])))