{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def plot_reward_functions():\n",
    "    # Generate retry counts from 0 to 15\n",
    "    retries = np.linspace(0, 15, 100)\n",
    "\n",
    "    # 1. Basic Sigmoid\n",
    "    basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n",
    "\n",
    "    # 2. Our Modified Sigmoid\n",
    "    x = retries - 4  # Center at 4 retries\n",
    "    modified_sigmoid = 1 / (1 + np.exp(-x + abs(x) / 2))\n",
    "\n",
    "    # 3. With Penalty\n",
    "    penalized_reward = modified_sigmoid.copy()\n",
    "    for i, r in enumerate(retries):\n",
    "        if r > 6:\n",
    "            penalty = 0.2 * (r - 6)\n",
    "            penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n",
    "\n",
    "    # Plotting\n",
    "    plt.figure(figsize=(12, 6))\n",
    "\n",
    "    plt.plot(retries, basic_sigmoid, \"b--\", label=\"Basic Sigmoid\")\n",
    "    plt.plot(retries, modified_sigmoid, \"g--\", label=\"Modified Sigmoid\")\n",
    "    plt.plot(\n",
    "        retries,\n",
    "        penalized_reward,\n",
    "        \"r-\",\n",
    "        label=\"Final Reward (with penalty)\",\n",
    "        linewidth=2,\n",
    "    )\n",
    "\n",
    "    # Add vertical lines for key points\n",
    "    plt.axvline(x=4, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Peak (4 retries)\")\n",
    "    plt.axvline(\n",
    "        x=6, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Penalty Start (6 retries)\"\n",
    "    )\n",
    "\n",
    "    plt.grid(True, alpha=0.3)\n",
    "    plt.xlabel(\"Number of Retries\")\n",
    "    plt.ylabel(\"Reward\")\n",
    "    plt.title(\"Reward Function Visualization\")\n",
    "    plt.legend()\n",
    "    plt.ylim(-0.1, 1.1)\n",
    "\n",
    "    # Add annotations\n",
    "    plt.annotate(\n",
    "        \"Optimal Zone\",\n",
    "        xy=(4, 0.8),\n",
    "        xytext=(4, 0.9),\n",
    "        ha=\"center\",\n",
    "        va=\"bottom\",\n",
    "        bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"yellow\", alpha=0.3),\n",
    "        arrowprops=dict(arrowstyle=\"->\"),\n",
    "    )\n",
    "\n",
    "    plt.annotate(\n",
    "        \"Penalty Zone\",\n",
    "        xy=(8, 0.3),\n",
    "        xytext=(8, 0.5),\n",
    "        ha=\"center\",\n",
    "        va=\"bottom\",\n",
    "        bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"red\", alpha=0.3),\n",
    "        arrowprops=dict(arrowstyle=\"->\"),\n",
    "    )\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Run the visualization\n",
    "plot_reward_functions()\n",
    "\n",
    "\n",
    "# Print reward values for specific retry counts\n",
    "def print_reward_examples():\n",
    "    retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
    "    print(\"\\nReward values for different retry counts:\")\n",
    "    print(\"Retries | Reward\")\n",
    "    print(\"-\" * 20)\n",
    "\n",
    "    for retries in retry_examples:\n",
    "        x = retries - 4\n",
    "        reward = 1 / (1 + np.exp(-x + abs(x) / 2))\n",
    "        if retries > 6:\n",
    "            penalty = 0.2 * (retries - 6)\n",
    "            reward = max(0.1, reward - penalty)\n",
    "        print(f\"{retries:7d} | {reward:.3f}\")\n",
    "\n",
    "\n",
    "print_reward_examples()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}