# Originally installed via `!pip install matplotlib -q`; in a finished notebook
# prefer `%pip install -q matplotlib==<version>` (pinned, `%pip` targets the
# kernel's environment) in its own top cell.

import numpy as np
import matplotlib.pyplot as plt

# Reward-shape constants (previously magic numbers repeated in two functions).
PEAK_RETRIES = 4     # sigmoid centre: reward ramps up to ~1 around this count
PENALTY_START = 6    # retries beyond this incur a linear penalty
PENALTY_RATE = 0.2   # reward subtracted per retry past PENALTY_START
REWARD_FLOOR = 0.1   # penalised reward is clamped to at least this value


def _modified_sigmoid(retries):
    """Asymmetric sigmoid centred at PEAK_RETRIES.

    The ``abs(x)/2`` term halves the effective slope to the right of the
    centre, so the reward rises quickly up to the peak and decays more
    gently afterwards. Accepts scalars or numpy arrays.
    """
    x = np.asarray(retries, dtype=float) - PEAK_RETRIES
    return 1.0 / (1.0 + np.exp(-x + np.abs(x) / 2.0))


def compute_reward(retries):
    """Final reward: modified sigmoid with a linear penalty past PENALTY_START.

    Shared by the plot and the printed table (the formula was previously
    duplicated in both). Accepts a scalar or a numpy array of retry counts
    and returns a matching scalar/array.
    """
    r = np.asarray(retries, dtype=float)
    reward = _modified_sigmoid(r)
    penalty = PENALTY_RATE * (r - PENALTY_START)
    # Vectorised replacement for the original per-element Python loop:
    # past PENALTY_START, subtract the penalty but never drop below the floor.
    penalized = np.where(r > PENALTY_START,
                         np.maximum(REWARD_FLOOR, reward - penalty),
                         reward)
    return float(penalized) if penalized.ndim == 0 else penalized


def plot_reward_functions():
    """Plot the basic sigmoid, modified sigmoid, and penalised reward curves."""
    retries = np.linspace(0, 15, 100)

    basic_sigmoid = 1.0 / (1.0 + np.exp(-(retries - PEAK_RETRIES)))
    modified_sigmoid = _modified_sigmoid(retries)
    penalized_reward = compute_reward(retries)

    # Explicit Axes interface instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')
    ax.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')
    ax.plot(retries, penalized_reward, 'r-',
            label='Final Reward (with penalty)', linewidth=2)

    # Vertical guides at the two structurally important retry counts.
    ax.axvline(x=PEAK_RETRIES, color='gray', linestyle=':', alpha=0.5,
               label='Peak (4 retries)')
    ax.axvline(x=PENALTY_START, color='gray', linestyle=':', alpha=0.5,
               label='Penalty Start (6 retries)')

    ax.grid(True, alpha=0.3)
    ax.set_xlabel('Number of Retries')
    ax.set_ylabel('Reward')
    ax.set_title('Reward Function Visualization')
    ax.legend()
    ax.set_ylim(-0.1, 1.1)

    ax.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),
                ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),
                arrowprops=dict(arrowstyle='->'))
    ax.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),
                ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),
                arrowprops=dict(arrowstyle='->'))

    plt.show()


def print_reward_examples():
    """Print a table of reward values at representative retry counts."""
    retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]
    print("\nReward values for different retry counts:")
    print("Retries | Reward")
    print("-" * 20)
    for retries in retry_examples:
        # Same formula as the plot — routed through the shared helper.
        print(f"{retries:7d} | {compute_reward(retries):.3f}")


# Run the visualization, then the numeric examples.
plot_reward_functions()
print_reward_examples()