fork download
  1. #include <iostream>
  2. #include <string>
  3. using namespace std;
  4.  
  5. // Define node types for the expression tree
  6. enum NodeType { CONST, VAR, ADD, SUB, MUL, POW };
  7.  
  8. // Basic Node structure
  9. struct Node {
  10. NodeType type;
  11. double val; // Used if type is CONST
  12. string varName; // Used if type is VAR
  13. Node *left, *right;
  14.  
  15. Node(NodeType t) : type(t), val(0), left(nullptr), right(nullptr) {}
  16. };
  17.  
  18. // --- HELPER FUNCTIONS ---
  19.  
  20. // Create a constant node
  21. Node *createConst(double v) {
  22. Node *n = new Node(CONST);
  23. n->val = v;
  24. return n;
  25. }
  26.  
  27. // Create a variable node
  28. Node *createVar(string name) {
  29. Node *n = new Node(VAR);
  30. n->varName = name;
  31. return n;
  32. }
  33.  
  34. // Create an operator node
  35. Node *createOp(NodeType t, Node *l, Node *r) {
  36. Node *n = new Node(t);
  37. n->left = l;
  38. n->right = r;
  39. return n;
  40. }
  41.  
  42. // DEEP COPY: Recursively clones a tree to prevent memory sharing
  43. Node *copyTree(Node *root) {
  44. if (!root)
  45. return nullptr;
  46. Node *newNode = new Node(root->type);
  47. newNode->val = root->val;
  48. newNode->varName = root->varName;
  49. newNode->left = copyTree(root->left);
  50. newNode->right = copyTree(root->right);
  51. return newNode;
  52. }
  53.  
  54. // Recursive function to free memory
  55. void deleteTree(Node *root) {
  56. if (!root)
  57. return;
  58. deleteTree(root->left);
  59. deleteTree(root->right);
  60. delete root;
  61. }
  62.  
  63. // Print the expression in human-readable format
  64. void printTree(Node *root) {
  65. if (!root)
  66. return;
  67. if (root->type == CONST)
  68. cout << root->val;
  69. else if (root->type == VAR)
  70. cout << root->varName;
  71. else {
  72. cout << "(";
  73. printTree(root->left);
  74. if (root->type == ADD)
  75. cout << " + ";
  76. else if (root->type == SUB)
  77. cout << " - ";
  78. else if (root->type == MUL)
  79. cout << " * ";
  80. else if (root->type == POW)
  81. cout << " ^ ";
  82. printTree(root->right);
  83. cout << ")";
  84. }
  85. }
  86.  
  87. // --- DIFFERENTIATION ALGORITHM ---
  88. Node *derive(Node *n, string var) {
  89. switch (n->type) {
  90. case CONST:
  91. return createConst(0); // (C)' = 0
  92. case VAR:
  93. return createConst(n->varName == var ? 1 : 0); // (x)' = 1, (y)' = 0
  94. case ADD:
  95. return createOp(ADD, derive(n->left, var), derive(n->right, var));
  96. case SUB:
  97. return createOp(SUB, derive(n->left, var), derive(n->right, var));
  98. case MUL: {
  99. // Product Rule: (u*v)' = u'v + uv'
  100. // We use copyTree to ensure the new tree has its own nodes
  101. Node *leftPart = createOp(MUL, derive(n->left, var), copyTree(n->right));
  102. Node *rightPart = createOp(MUL, copyTree(n->left), derive(n->right, var));
  103. return createOp(ADD, leftPart, rightPart);
  104. }
  105. case POW: {
  106. // Power Rule: (u^n)' = n * u^(n-1) * u'
  107. // Assumes n is a constant
  108. double nVal = n->right->val;
  109. Node *nMinus1 = createConst(nVal - 1);
  110. Node *newPow = createOp(POW, copyTree(n->left), nMinus1);
  111. Node *step1 = createOp(MUL, createConst(nVal), newPow);
  112. return createOp(MUL, step1, derive(n->left, var));
  113. }
  114. }
  115. return nullptr;
  116. }
  117.  
  118. // --- SIMPLIFICATION ALGORITHM ---
  119. Node *simplify(Node *n) {
  120. if (!n || n->type == CONST || n->type == VAR)
  121. return n;
  122. // Simplify children first (Post-order traversal)
  123. n->left = simplify(n->left);
  124. n->right = simplify(n->right);
  125. // Simplify ADD (+)
  126. if (n->type == ADD) {
  127. if (n->left->type == CONST && n->left->val == 0)
  128. return n->right; // 0 + x = x
  129. if (n->right->type == CONST && n->right->val == 0)
  130. return n->left; // x + 0 = x
  131. }
  132. // Simplify MUL (*)
  133. else if (n->type == MUL) {
  134. if (n->left->type == CONST && n->left->val == 0)
  135. return createConst(0); // 0 * x = 0
  136. if (n->right->type == CONST && n->right->val == 0)
  137. return createConst(0); // x * 0 = 0
  138. if (n->left->type == CONST && n->left->val == 1)
  139. return n->right; // 1 * x = x
  140. if (n->right->type == CONST && n->right->val == 1)
  141. return n->left; // x * 1 = x
  142. }
  143. // Simplify POW (^)
  144. else if (n->type == POW) {
  145. if (n->right->type == CONST && n->right->val == 1)
  146. return n->left; // x ^ 1 = x
  147. if (n->right->type == CONST && n->right->val == 0)
  148. return createConst(1); // x ^ 0 = 1
  149. }
  150. // Constant Folding: if both sides are numbers, calculate immediately
  151. if (n->left->type == CONST && n->right->type == CONST) {
  152. if (n->type == ADD)
  153. return createConst(n->left->val + n->right->val);
  154. if (n->type == SUB)
  155. return createConst(n->left->val - n->right->val);
  156. if (n->type == MUL)
  157. return createConst(n->left->val * n->right->val);
  158. }
  159. return n;
  160. }
  161.  
  162. // --- MAIN FUNCTION ---
  163. int main() {
  164. // f(x) = x^2 + 3x
  165. Node *x = createVar("x");
  166. Node *expr = createOp(ADD, createOp(POW, x, createConst(2)),
  167. createOp(MUL, createConst(3), copyTree(x)));
  168.  
  169. cout << "Original Expression: ";
  170. printTree(expr);
  171. cout << endl;
  172. // Calculate Derivative
  173. Node *d = derive(expr, "x");
  174. cout << "Raw Derivative: ";
  175. printTree(d);
  176. cout << endl;
  177. // Simplify the derivative
  178. // We run it twice to ensure nested simplifications (like 0 + (3 * 1)) are
  179. // fully resolved
  180. d = simplify(d);
  181. d = simplify(d);
  182. cout << "Simplified Result: ";
  183. printTree(d);
  184. cout << endl;
  185. // Clean up memory
  186. deleteTree(expr);
  187. deleteTree(d);
  188.  
  189. return 0;
  190. }
Success #stdin #stdout 0.01s 5320KB
stdin
Standard input is empty
stdout
Original Expression: ((x ^ 2) + (3 * x))
Raw Derivative:      (((2 * (x ^ 1)) * 1) + ((0 * x) + (3 * 1)))
Simplified Result:   ((2 * x) + 3)