1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
# Recursive helper used by Module.parameters() to find all Parameters.
def _get_params(value: object) -> List[Parameter]:
    """Recursively collect every ``Parameter`` reachable from *value*.

    Containers are searched depth-first in iteration order:

    * a ``Parameter`` yields itself;
    * a ``dict`` is searched through its values (keys are ignored);
    * a ``Module`` delegates to its own ``parameters()``;
    * a ``list`` / ``tuple`` is searched element by element.

    Any other value (numbers, strings, ``None``, plain tensors, ...) holds no
    parameters and contributes an empty list.

    Args:
        value: The object to search, typically a module's ``__dict__``.

    Returns:
        A new list of the Parameters found; duplicates are not removed.
    """
    if isinstance(value, Parameter):
        return [value]
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, Module):
        return value.parameters()
    elif isinstance(value, (list, tuple)):
        # Lets modules keep sub-layers in a list (e.g. a sequential stack).
        children = value
    else:
        # Leaf that is not a Parameter.
        # Returning [] (not falling off the end with None) keeps the callers'
        # list concatenation valid for ordinary attributes such as ints or strings.
        return []
    params = []
    for child in children:
        params.extend(_get_params(child))
    return params
class Module:
    """Base class for layers: parameter discovery plus call-to-forward dispatch.

    Subclasses store their Parameters (and sub-modules) as ordinary instance
    attributes and implement ``forward``; calling the instance runs ``forward``.
    """

    def parameters(self) -> List[Parameter]:
        """Return the Parameters reachable from this module's attributes.

        The instance ``__dict__`` is searched recursively by ``_get_params``,
        so Parameters held directly or inside nested Modules and dicts are
        included, in attribute-insertion order.
        """
        return _get_params(self.__dict__)

    def forward(self, *args, **kwargs):
        """Compute the module's output; every concrete subclass must override this."""
        # Explicit stub so that calling a module that forgot to define forward
        # fails with a clear message instead of an opaque AttributeError.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )

    def __call__(self, *args, **kwargs):
        """Invoke ``self.forward`` with the same positional and keyword arguments."""
        return self.forward(*args, **kwargs)
class ScalarAdd(Module):
    """Learnable scalar affine map: ``forward(x) = s * x + b``.

    Args:
        init_s: Initial value of the multiplicative parameter ``s``.
        init_b: Initial value of the additive parameter ``b``.
    """

    def __init__(self, init_s=1, init_b=0):
        # Each scalar is wrapped as a one-element float32 Parameter and stored
        # as an instance attribute, where Module.parameters() will find it.
        # Assignment order (s, then b) determines the order parameters() reports.
        self.s = Parameter([init_s], dtype="float32")
        self.b = Parameter([init_b], dtype="float32")

    def forward(self, x):
        """Return ``self.s * x + self.b``."""
        scaled = self.s * x
        return scaled + self.b
|