Practice: Apply the chain rule

  1. Define a custom expression node Derivative(expr, v) that symbolically represents taking a derivative of an expression expr with respect to variable v.
  2. Now suppose that, to take a derivative with respect to a given coordinate x, your code actually has to work in a reference coordinate system with coordinates r and s, and therefore needs to apply the chain rule identity
$$ \frac{\partial\,\text{expr}}{\partial x} = \frac{\partial\,\text{expr}}{\partial r}\frac{\partial r}{\partial x} + \frac{\partial\,\text{expr}}{\partial s}\frac{\partial s}{\partial x}$$

Write a ChainRuleMapper that applies this identity.
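Before writing any symbolic code, it can help to check the identity numerically. The sketch below is independent of pymbolic and uses a hypothetical function f(r, s) = r²s and a hypothetical parametrization r(x) = 2x, s(x) = x + 1, chosen only for illustration. It evaluates the right-hand side of the identity from hand-computed partial derivatives and compares it with a central finite difference of the composed function x ↦ f(r(x), s(x)).

```python
# Hypothetical example functions, chosen for illustration only.
def f(r, s):
    return r**2 * s


def df_dr(r, s):  # partial derivative of f with respect to r
    return 2 * r * s


def df_ds(r, s):  # partial derivative of f with respect to s
    return r**2


def r_of(x):
    return 2 * x


def s_of(x):
    return x + 1


DR_DX = 2.0  # dr/dx for r(x) = 2x
DS_DX = 1.0  # ds/dx for s(x) = x + 1


def chain_rule_derivative(x):
    """Right-hand side of the identity: df/dr * dr/dx + df/ds * ds/dx."""
    r, s = r_of(x), s_of(x)
    return df_dr(r, s) * DR_DX + df_ds(r, s) * DS_DX


def finite_difference_derivative(x, h=1e-6):
    """Central difference of the composed function g(x) = f(r(x), s(x))."""
    def g(t):
        return f(r_of(t), s_of(t))
    return (g(x + h) - g(x - h)) / (2 * h)


x0 = 1.5
print(chain_rule_derivative(x0))         # 39.0 (= 12*x0**2 + 8*x0)
print(finite_difference_derivative(x0))  # approximately 39.0
```

Substituting gives f(r(x), s(x)) = 4x³ + 4x², whose derivative 12x² + 8x matches the chain-rule sum at every x.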

In [7]:
from pymbolic import var
from pymbolic.primitives import Expression
from pymbolic.mapper import IdentityMapper

x = var("x")
r = var("r")
s = var("s")
In [8]:
class Derivative(Expression):
    # ...
    pass

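pymbolic already provides its own Derivative node (pymbolic.primitives.Derivative), which mappers handle through the method map_derivative. If our node reused that name, existing mappers would call their map_derivative on our node, which has different attributes. To avoid that clash, our node uses the mapper method name map_deriv.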
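The mapper method name matters because pymbolic mappers dispatch on the node's mapper_method attribute. Calling a mapper on a node looks up the method with that name on the mapper. The sketch below is a simplified, pure-Python stand-in for that mechanism. The classes Node, Var, Deriv, Mapper and ToString are hypothetical and are not pymbolic's API. It shows why defining map_deriv on a mapper is all it takes for that mapper to handle the new node type.

```python
class Node:
    """Base class: each concrete node names the mapper method that handles it."""
    mapper_method = None


class Var(Node):
    mapper_method = "map_variable"

    def __init__(self, name):
        self.name = name


class Deriv(Node):
    mapper_method = "map_deriv"

    def __init__(self, expr, v):
        self.expr = expr
        self.v = v


class Mapper:
    def __call__(self, expr):
        # Dispatch: look up the method whose name the node advertises.
        method = getattr(self, expr.mapper_method, None)
        if method is None:
            raise NotImplementedError(
                f"{type(self).__name__} cannot handle {type(expr).__name__}")
        return method(expr)

    # Recursive calls go through the same dispatch.
    rec = __call__


class ToString(Mapper):
    def map_variable(self, expr):
        return expr.name

    def map_deriv(self, expr):
        return f"d({self.rec(expr.expr)})/d{self.rec(expr.v)}"


print(ToString()(Deriv(Var("u"), Var("x"))))  # d(u)/dx
```

pymbolic's real mappers also forward extra arguments through this dispatch, such as the enclosing_prec that StringifyMapper passes to its map_ methods. This sketch keeps only the lookup step.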

In [32]:
# Solution

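class Derivative(Expression):
    """Symbolic derivative of expr with respect to the variable v."""

    def __init__(self, expr, v):
        self.expr = expr
        self.v = v

    def __getinitargs__(self):
        # pymbolic builds the default repr, equality and hashing of the
        # node from these constructor arguments.
        return (self.expr, self.v)

    # Name of the method that mappers call to handle this node type.
    mapper_method = "map_deriv"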
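__getinitargs__ is what lets pymbolic treat Derivative like its built-in nodes. In the pymbolic versions this notebook targets, Expression uses the tuple it returns to build the repr printed in the next cell, and to implement equality and hashing. The pure-Python stand-in below shows that idea in isolation. The class names ArgsExpression and Deriv are hypothetical and are not pymbolic's code.

```python
class ArgsExpression:
    """Hypothetical stand-in: repr, equality and hashing all derive from
    __getinitargs__, similar to older pymbolic Expression subclasses."""

    def __getinitargs__(self):
        raise NotImplementedError

    def __repr__(self):
        args = ", ".join(repr(a) for a in self.__getinitargs__())
        return f"{type(self).__name__}({args})"

    def __eq__(self, other):
        return (type(other) is type(self)
                and self.__getinitargs__() == other.__getinitargs__())

    def __hash__(self):
        return hash((type(self).__name__,) + self.__getinitargs__())


class Deriv(ArgsExpression):
    def __init__(self, expr, v):
        self.expr = expr
        self.v = v

    def __getinitargs__(self):
        return (self.expr, self.v)


a = Deriv("u", "x")
b = Deriv("u", "x")
print(repr(a))         # Deriv('u', 'x')
print(a == b, a is b)  # True False
print(len({a, b}))     # 1: equal nodes hash equally
```

Because equality is structural, two separately constructed derivative nodes with the same arguments compare equal, which is what you want when comparing or deduplicating symbolic expressions.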
In [33]:
expr = var("sqrt")(Derivative(27*x**2+var("exp")(x), x))
print(repr(expr))
Call(Variable('sqrt'), (Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('x')),))
In [34]:
class ChainRuleMapper(IdentityMapper):
    # ...
    pass
In [37]:
# Solution

class ChainRuleMapper(IdentityMapper):
    def map_deriv(self, expr):
        # expr is Derivative(f, v). Rewrite it as
        #   df/dr * dr/dv + df/ds * ds/dv,
        # first recursing into f so that nested derivatives are rewritten too.
        inner = self.rec(expr.expr)
        return sum(Derivative(inner, ref_sym)*Derivative(ref_sym, expr.v)
                   for ref_sym in [r, s])

Now let's test this mapper:

In [38]:
crm = ChainRuleMapper()
crm(expr)
Out[38]:
Call(Variable('sqrt'), (Sum((Product((Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('r')), Derivative(Variable('r'), Variable('x')))), Product((Derivative(Sum((Product((27, Power(Variable('x'), 2))), Call(Variable('exp'), (Variable('x'),)))), Variable('s')), Derivative(Variable('s'), Variable('x')))))),))

The output above is in the verbose, parenthesis-heavy repr form because we have not yet told pymbolic how to print a Derivative in its compact infix notation. Here is how to do that. Since pymbolic's StringifyMapper hands any node it has no map_ method for to that node's own stringifier, attaching a stringifier to Derivative is enough, even though the outermost node printed here is a Call:

In [48]:
from pymbolic.mapper.stringifier import StringifyMapper, PREC_PRODUCT

class MyStringifyMapper(StringifyMapper):
    # Called for Derivative nodes (mapper_method = "map_deriv").
    def map_deriv(self, expr, enclosing_prec):
        return "d(%s)/d%s" % (
            self.rec(expr.expr, PREC_PRODUCT),
            self.rec(expr.v, PREC_PRODUCT))

# Tell pymbolic which stringifier class to use for Derivative nodes.
def stringifier(self):
    return MyStringifyMapper

Derivative.stringifier = stringifier
print(crm(expr))
sqrt(d((27*x**2 + exp(x)))/dr*d(r)/dx + d((27*x**2 + exp(x)))/ds*d(s)/dx)
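To see the recursion that IdentityMapper provides, here is a self-contained, pure-Python sketch of a chain-rule rewrite on a toy expression tree. The node classes (Var, Call, Add, Mul, Deriv) and the mappers are hypothetical stand-ins for pymbolic's primitives and mappers, not their API. The identity mapper rebuilds every node from its rewritten children. Mirroring the structure of ChainRuleMapper, the chain-rule mapper overrides only the derivative case, rewriting d(f)/dv as the sum over q in (r, s) of d(f)/dq * d(q)/dv.

```python
from dataclasses import dataclass


# Hypothetical toy node types (frozen dataclasses give structural == and hash).
# Each names the mapper method that handles it, as pymbolic nodes do.
@dataclass(frozen=True)
class Var:
    name: str
    mapper_method = "map_variable"


@dataclass(frozen=True)
class Call:
    func: str
    args: tuple
    mapper_method = "map_call"


@dataclass(frozen=True)
class Add:
    terms: tuple
    mapper_method = "map_sum"


@dataclass(frozen=True)
class Mul:
    factors: tuple
    mapper_method = "map_product"


@dataclass(frozen=True)
class Deriv:
    expr: object
    v: Var
    mapper_method = "map_deriv"


class Mapper:
    def __call__(self, expr):
        # Dispatch on the node's mapper_method name.
        return getattr(self, expr.mapper_method)(expr)
    rec = __call__


class IdentityMapper(Mapper):
    """Rebuilds every node from its recursively mapped children."""
    def map_variable(self, expr):
        return expr

    def map_call(self, expr):
        return Call(expr.func, tuple(self.rec(a) for a in expr.args))

    def map_sum(self, expr):
        return Add(tuple(self.rec(t) for t in expr.terms))

    def map_product(self, expr):
        return Mul(tuple(self.rec(f) for f in expr.factors))

    def map_deriv(self, expr):
        return Deriv(self.rec(expr.expr), expr.v)


R, S = Var("r"), Var("s")  # reference coordinates


class ChainRule(IdentityMapper):
    """Overrides only the derivative case."""
    def map_deriv(self, expr):
        inner = self.rec(expr.expr)  # rewrite nested derivatives first
        return Add(tuple(Mul((Deriv(inner, q), Deriv(q, expr.v)))
                         for q in (R, S)))


class ToString(Mapper):
    def map_variable(self, expr):
        return expr.name

    def map_call(self, expr):
        return f"{expr.func}({', '.join(self.rec(a) for a in expr.args)})"

    def map_sum(self, expr):
        return " + ".join(self.rec(t) for t in expr.terms)

    def map_product(self, expr):
        # Parenthesize sums that appear as factors.
        return "*".join(f"({self.rec(f)})" if isinstance(f, Add) else self.rec(f)
                        for f in expr.factors)

    def map_deriv(self, expr):
        return f"d({self.rec(expr.expr)})/d{self.rec(expr.v)}"


x = Var("x")
expr = Call("sqrt", (Deriv(Add((Mul((x, x)), Call("exp", (x,)))), x),))
print(ToString()(ChainRule()(expr)))
# sqrt(d(x*x + exp(x))/dr*d(r)/dx + d(x*x + exp(x))/ds*d(s)/dx)
```

Because the chain-rule mapper inherits every other case, the sqrt call and the sum inside the derivative are rebuilt unchanged. Only the derivative node itself is replaced.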