class Node:
    """Base class for the nodes of an arithmetic expression graph.

    The ``+``, ``-`` and ``*`` operators do not compute a number; they build
    a SumNode, DiffNode or ProdNode that refers to both operands. An operand
    that is not a Node is first wrapped as ``ConstantNode(float(operand))``.
    """

    def value(self):
        """Return the value of this node; the base implementation returns None."""
        pass

    @staticmethod
    def _as_node(x):
        """Return x unchanged if it is a Node, else ConstantNode(float(x))."""
        if isinstance(x, Node):
            return x
        return ConstantNode(float(x))

    def __add__(self, x):
        """Build the node for ``self + x``."""
        return SumNode(self, Node._as_node(x))

    def __sub__(self, x):
        """Build the node for ``self - x``."""
        return DiffNode(self, Node._as_node(x))

    def __mul__(self, x):
        """Build the node for ``self * x``."""
        return ProdNode(self, Node._as_node(x))

    def __radd__(self, x):
        """Build the node for ``x + self``; operands are stored as (self, x)."""
        return self.__add__(x)

    def __rsub__(self, x):
        """Build the node for ``x - self``."""
        # Subtraction is not commutative: x is the LEFT operand here, so it
        # must be the first argument. Delegating to __sub__ would produce
        # ``self - x`` instead.
        return DiffNode(Node._as_node(x), self)

    def __rmul__(self, x):
        """Build the node for ``x * self``; operands are stored as (self, x)."""
        return self.__mul__(x)
class ParameterNode(Node):
    """Node that holds a value which can be replaced after construction."""

    def __init__(self, param_value):
        """Store param_value as the node's current value."""
        self.param_value = param_value

    def set_value(self, pv):
        """Replace the stored value with pv."""
        self.param_value = pv

    def value(self):
        """Return the value currently stored in the node."""
        return self.param_value
class ConstantNode(Node):
    """A node for a constant value."""

    def __init__(self, const_value):
        """Store const_value as the value this node reports."""
        self.const_value = const_value

    def value(self):
        """Return the value given at construction."""
        return self.const_value
class SumNode(Node):
    """Node representing ``arg1 + arg2``."""

    def __init__(self, arg1, arg2):
        """Keep references to the left (arg1) and right (arg2) operands."""
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return arg1.value() + arg2.value(), evaluating both on each call."""
        return self.arg1.value() + self.arg2.value()
class DiffNode(Node):
    """Node representing ``arg1 - arg2``."""

    def __init__(self, arg1, arg2):
        """Keep references to the left (arg1) and right (arg2) operands."""
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return arg1.value() - arg2.value(), evaluating both on each call."""
        return self.arg1.value() - self.arg2.value()
class ProdNode(Node):
    """Node representing ``arg1 * arg2``."""

    def __init__(self, arg1, arg2):
        """Keep references to the left (arg1) and right (arg2) operands."""
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return arg1.value() * arg2.value(), evaluating both on each call."""
        return self.arg1.value() * self.arg2.value()