-
Hello all,
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
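Hi @Sajedeh1996, please check the possible solution below. As you can see, you'll need to define function wrappers (`Adual` and `Bdual` below) that inspect whether the incoming `x` and `y` arguments have been "seeded" in their `grad` member:
#include <iostream>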
#include <autodiff/forward/dual.hpp>
using namespace autodiff;
// Example: C(x, y) = A(x, y)^2 + B(x, y), where A and B are "external" functions
// that only work on plain doubles and come with hand-written partial derivatives.
// Adual/Bdual lift them to autodiff's `dual` type so derivative() can flow
// through them.

/// A(x, y) = x*y
auto A(double x, double y) -> double { return x*y; }
/// dA/dx = y  (the original returned x here, which is dA/dy)
auto Ax(double /*x*/, double y) -> double { return y; }
/// dA/dy = x  (the original returned y here, which is dA/dx)
auto Ay(double x, double /*y*/) -> double { return x; }

/// B(x, y) = x + y
auto B(double x, double y) -> double { return x + y; }
/// dB/dx = 1
auto Bx(double /*x*/, double /*y*/) -> double { return 1.0; }
/// dB/dy = 1
auto By(double /*x*/, double /*y*/) -> double { return 1.0; }

/// Lifts A to dual numbers.
/// The value comes from A(x.val, y.val). The derivative is assembled with the
/// chain rule: grad = x.grad * dA/dx + y.grad * dA/dy, using the hand-written
/// partials Ax/Ay. A term is skipped when that argument's grad is 0 (not seeded).
auto Adual(dual const& x, dual const& y) -> dual
{
    // Relies on dual's construction from a double starting with grad == 0.
    dual res = A(x.val, y.val);
    if(x.grad != 0.0)
        res.grad += x.grad * Ax(x.val, y.val);
    if(y.grad != 0.0)
        res.grad += y.grad * Ay(x.val, y.val);
    return res;
}

/// Lifts B to dual numbers, in the same way as Adual but with Bx/By.
auto Bdual(dual const& x, dual const& y) -> dual
{
    dual res = B(x.val, y.val);
    if(x.grad != 0.0)
        res.grad += x.grad * Bx(x.val, y.val);
    if(y.grad != 0.0)
        res.grad += y.grad * By(x.val, y.val);
    return res;
}

/// C(x, y) = A(x, y)^2 + B(x, y), evaluated on duals so derivative() can
/// propagate the seeded variable through Adual/Bdual.
auto C(dual const& x, dual const& y) -> dual
{
    // Lower-case names avoid shadowing the functions A and B.
    const auto a = Adual(x, y);
    const auto b = Bdual(x, y);
    return a*a + b;
}

int main()
{
    dual x = 1.0;
    dual y = 2.0;

    auto C0 = C(x, y);

    // Derivatives of C computed by autodiff through the custom wrappers.
    auto Cx = derivative(C, wrt(x), at(x, y));
    auto Cy = derivative(C, wrt(y), at(x, y));

    // Hand-derived chain rule for comparison: dC/dx = 2*A*dA/dx + dB/dx, and
    // likewise for y. At (1, 2): dC/dx = 2*2*2 + 1 = 9, dC/dy = 2*2*1 + 1 = 5.
    auto expectedCx = 2.0*A(x.val, y.val)*Ax(x.val, y.val) + Bx(x.val, y.val);
    auto expectedCy = 2.0*A(x.val, y.val)*Ay(x.val, y.val) + By(x.val, y.val);

    std::cout << "C0 = " << C0 << "\n";
    std::cout << "Cx(computed) = " << Cx << "\n";
    std::cout << "Cx(expected) = " << expectedCx << "\n";
    std::cout << "Cy(computed) = " << Cy << "\n";
    std::cout << "Cy(expected) = " << expectedCy << "\n";
}

// Output:
// C0 = 7
// Cx(computed) = 9
// Cx(expected) = 9
// Cy(computed) = 5
// Cy(expected) = 5
// This example has now been added to the list of examples and is available on the website:
Beta Was this translation helpful? Give feedback.
Hi @Sajedeh1996, please check the possible solution below. As you can see, you'll need to define function wrappers `Adual` and `Bdual`, as shown below, which inspect whether the incoming `x` and `y` arguments have been "seeded" in their `grad` member: