#include <cmath>
#include <iostream>
#include <stdexcept>
#include <string>
using namespace std;
// Kinds of node that may appear in an expression tree: two leaf kinds
// (CONST, VAR) followed by the binary operators (ADD, SUB, MUL, POW).
enum NodeType {
  CONST = 0, // numeric literal; its value lives in Node::val
  VAR = 1,   // named variable; its name lives in Node::varName
  ADD = 2,   // left + right
  SUB = 3,   // left - right
  MUL = 4,   // left * right
  POW = 5    // left ^ right
};
// One node of an expression tree. Leaves carry their payload in `val`
// (CONST) or `varName` (VAR); operator nodes link their operands through
// `left` and `right`. Every field other than `type` starts out empty.
struct Node {
  NodeType type;
  double val = 0;       // payload when type == CONST
  string varName;       // payload when type == VAR
  Node *left = nullptr; // left operand of an operator node
  Node *right = nullptr; // right operand of an operator node
  Node(NodeType t) : type(t) {}
};
// --- HELPER FUNCTIONS ---
// Allocates (with `new`) a CONST leaf holding `v`. The caller is responsible
// for freeing it, e.g. through deleteTree.
Node *createConst(double v) {
  auto *leaf = new Node(CONST);
  leaf->val = v;
  return leaf;
}
// Allocates (with `new`) a VAR leaf whose name is `name`. The caller is
// responsible for freeing it, e.g. through deleteTree.
Node *createVar(string name) {
  auto *leaf = new Node(VAR);
  leaf->varName = name;
  return leaf;
}
// Allocates an operator node of kind `t` with operands `l` and `r`.
// The operands are linked, not copied, so they become part of the new
// node's tree (deleteTree on the result frees them as well).
Node *createOp(NodeType t, Node *l, Node *r) {
  auto *op = new Node(t);
  op->left = l;
  op->right = r;
  return op;
}
// DEEP COPY: returns a freshly allocated clone of the tree rooted at `root`
// that shares no nodes with the original (nullptr for an empty tree), so
// both trees can be freed independently.
Node *copyTree(Node *root) {
  Node *clone = nullptr;
  if (root != nullptr) {
    clone = new Node(root->type);
    clone->val = root->val;
    clone->varName = root->varName;
    clone->left = copyTree(root->left);
    clone->right = copyTree(root->right);
  }
  return clone;
}
// Frees `root` and every node beneath it. Children are released before
// their parent (post-order). Passing nullptr is a no-op.
void deleteTree(Node *root) {
  if (root != nullptr) {
    deleteTree(root->left);
    deleteTree(root->right);
    delete root;
  }
}
// Writes the expression to stdout in fully parenthesised infix form,
// e.g. ((x ^ 2) + (3 * x)). An empty tree prints nothing.
void printTree(Node *root) {
  if (root == nullptr)
    return;

  // Leaves print their payload directly.
  switch (root->type) {
  case CONST:
    cout << root->val;
    return;
  case VAR:
    cout << root->varName;
    return;
  default:
    break;
  }

  // Operators: "(" left <symbol> right ")".
  const char *symbol = "";
  switch (root->type) {
  case ADD:
    symbol = " + ";
    break;
  case SUB:
    symbol = " - ";
    break;
  case MUL:
    symbol = " * ";
    break;
  case POW:
    symbol = " ^ ";
    break;
  default:
    break;
  }
  cout << "(";
  printTree(root->left);
  cout << symbol;
  printTree(root->right);
  cout << ")";
}
// --- DIFFERENTIATION ALGORITHM ---
Node *derive(Node *n, string var) {
switch (n->type) {
case CONST:
return createConst(0); // (C)' = 0
case VAR:
return createConst(n->varName == var ? 1 : 0); // (x)' = 1, (y)' = 0
case ADD:
return createOp(ADD, derive(n->left, var), derive(n->right, var));
case SUB:
return createOp(SUB, derive(n->left, var), derive(n->right, var));
case MUL: {
// Product Rule: (u*v)' = u'v + uv'
// We use copyTree to ensure the new tree has its own nodes
Node *leftPart = createOp(MUL, derive(n->left, var), copyTree(n->right));
Node *rightPart = createOp(MUL, copyTree(n->left), derive(n->right, var));
return createOp(ADD, leftPart, rightPart);
}
case POW: {
// Power Rule: (u^n)' = n * u^(n-1) * u'
// Assumes n is a constant
double nVal = n->right->val;
Node *nMinus1 = createConst(nVal - 1);
Node *newPow = createOp(POW, copyTree(n->left), nMinus1);
Node *step1 = createOp(MUL, createConst(nVal), newPow);
return createOp(MUL, step1, derive(n->left, var));
}
}
return nullptr;
}
// --- SIMPLIFICATION ALGORITHM ---
// Frees `n` and whichever of its operand subtrees is not `keep`, then returns
// `keep`. This lets a rewrite rule replace a node without leaking the parts
// of the tree it throws away.
static Node *replaceNode(Node *n, Node *keep) {
  if (n->left != keep)
    deleteTree(n->left);
  if (n->right != keep)
    deleteTree(n->right);
  delete n;
  return keep;
}

// Rewrites the tree rooted at `n` bottom-up using the identities
// 0 + x, x + 0, x - 0, 0 * x, x * 0, 1 * x, x * 1, x ^ 1 and x ^ 0. It also
// folds operators whose two operands are both constants.
//
// Ownership: simplify() consumes `n` and returns the root of the simplified
// tree. Nodes dropped by a rewrite are freed, so use it as `t = simplify(t);`
// and do not keep other pointers into the old tree.
//
// Children are simplified before their parent, so a single call already
// reaches a fixed point; calling it again is harmless.
Node *simplify(Node *n) {
  if (!n || n->type == CONST || n->type == VAR)
    return n;
  // Simplify children first (post-order traversal).
  n->left = simplify(n->left);
  n->right = simplify(n->right);
  // An operator with a missing operand can be neither rewritten nor folded.
  if (!n->left || !n->right)
    return n;

  auto isConst = [](const Node *c, double v) {
    return c->type == CONST && c->val == v;
  };

  switch (n->type) {
  case ADD:
    if (isConst(n->left, 0))
      return replaceNode(n, n->right); // 0 + x = x
    if (isConst(n->right, 0))
      return replaceNode(n, n->left); // x + 0 = x
    break;
  case SUB:
    if (isConst(n->right, 0))
      return replaceNode(n, n->left); // x - 0 = x
    break;
  case MUL:
    if (isConst(n->left, 0) || isConst(n->right, 0))
      return replaceNode(n, createConst(0)); // 0 * x = x * 0 = 0
    if (isConst(n->left, 1))
      return replaceNode(n, n->right); // 1 * x = x
    if (isConst(n->right, 1))
      return replaceNode(n, n->left); // x * 1 = x
    break;
  case POW:
    if (isConst(n->right, 1))
      return replaceNode(n, n->left); // x ^ 1 = x
    if (isConst(n->right, 0))
      return replaceNode(n, createConst(1)); // x ^ 0 = 1
    break;
  default:
    break;
  }

  // Constant Folding: if both sides are numbers, calculate immediately.
  if (n->left->type == CONST && n->right->type == CONST) {
    const double a = n->left->val;
    const double b = n->right->val;
    switch (n->type) {
    case ADD:
      return replaceNode(n, createConst(a + b));
    case SUB:
      return replaceNode(n, createConst(a - b));
    case MUL:
      return replaceNode(n, createConst(a * b));
    case POW: {
      // Only fold real, finite results; e.g. (-8) ^ 0.5 or 0 ^ -1 stay
      // symbolic instead of becoming NaN/inf.
      const double p = std::pow(a, b);
      if (std::isfinite(p))
        return replaceNode(n, createConst(p));
      break;
    }
    default:
      break;
    }
  }
  return n;
}
// --- MAIN FUNCTION ---
// Demo: builds f(x) = x^2 + 3x, then prints the expression, its raw
// derivative with respect to x, and the simplified derivative.
int main() {
  // f(x) = x^2 + 3x. `x` is linked into the tree once; the second use is a
  // deep copy so no node is shared (deleteTree would otherwise free it twice).
  Node *x = createVar("x");
  Node *expr = createOp(ADD, createOp(POW, x, createConst(2)),
                        createOp(MUL, createConst(3), copyTree(x)));
  cout << "Original Expression: ";
  printTree(expr);
  cout << '\n';

  // Calculate Derivative
  Node *d = derive(expr, "x");
  cout << "Raw Derivative: ";
  printTree(d);
  cout << '\n';

  // Simplify the derivative. simplify() rewrites bottom-up, so children are
  // already in simplest form when a parent's rules run. A single pass
  // therefore resolves nested cases such as 0 + (3 * 1).
  d = simplify(d);
  cout << "Simplified Result: ";
  printTree(d);
  cout << '\n';

  // Clean up memory
  deleteTree(expr);
  deleteTree(d);
  return 0;
}