You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
4.0 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -q matplotlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_reward_functions(peak=4, penalty_start=6, penalty_rate=0.2, min_reward=0.1):\n",
"    \"\"\"Visualize the retry-count reward functions.\n",
"\n",
"    Plots three curves over retry counts 0-15:\n",
"      1. a basic sigmoid centered at ``peak``,\n",
"      2. a modified sigmoid that decays on both sides of the peak,\n",
"      3. the modified sigmoid with a linear penalty past ``penalty_start``.\n",
"\n",
"    The defaults reproduce the original hard-coded behavior.\n",
"\n",
"    Args:\n",
"        peak: retry count where the reward is highest.\n",
"        penalty_start: retry count after which the penalty applies.\n",
"        penalty_rate: penalty subtracted per retry beyond ``penalty_start``.\n",
"        min_reward: floor for the penalized reward.\n",
"    \"\"\"\n",
"    # Generate retry counts from 0 to 15\n",
"    retries = np.linspace(0, 15, 100)\n",
"\n",
"    # 1. Basic Sigmoid\n",
"    basic_sigmoid = 1 / (1 + np.exp(-(retries - peak)))\n",
"\n",
"    # 2. Modified sigmoid: the abs(x)/2 term makes the curve fall off on\n",
"    #    both sides of the peak instead of saturating at 1.\n",
"    x = retries - peak\n",
"    modified_sigmoid = 1 / (1 + np.exp(-x + np.abs(x) / 2))\n",
"\n",
"    # 3. Linear penalty past penalty_start, floored at min_reward.\n",
"    #    Vectorized replacement of the original element-wise loop; the\n",
"    #    np.where guard keeps sub-floor rewards below the threshold intact,\n",
"    #    exactly as the loop did.\n",
"    penalty = penalty_rate * (retries - penalty_start)\n",
"    penalized_reward = np.where(\n",
"        retries > penalty_start,\n",
"        np.maximum(min_reward, modified_sigmoid - penalty),\n",
"        modified_sigmoid,\n",
"    )\n",
"\n",
"    # Plotting\n",
"    plt.figure(figsize=(12, 6))\n",
"\n",
"    plt.plot(retries, basic_sigmoid, \"b--\", label=\"Basic Sigmoid\")\n",
"    plt.plot(retries, modified_sigmoid, \"g--\", label=\"Modified Sigmoid\")\n",
"    plt.plot(\n",
"        retries,\n",
"        penalized_reward,\n",
"        \"r-\",\n",
"        label=\"Final Reward (with penalty)\",\n",
"        linewidth=2,\n",
"    )\n",
"\n",
"    # Add vertical lines for key points (labels track the parameters)\n",
"    plt.axvline(\n",
"        x=peak, color=\"gray\", linestyle=\":\", alpha=0.5, label=f\"Peak ({peak} retries)\"\n",
"    )\n",
"    plt.axvline(\n",
"        x=penalty_start,\n",
"        color=\"gray\",\n",
"        linestyle=\":\",\n",
"        alpha=0.5,\n",
"        label=f\"Penalty Start ({penalty_start} retries)\",\n",
"    )\n",
"\n",
"    plt.grid(True, alpha=0.3)\n",
"    plt.xlabel(\"Number of Retries\")\n",
"    plt.ylabel(\"Reward\")\n",
"    plt.title(\"Reward Function Visualization\")\n",
"    plt.legend()\n",
"    plt.ylim(-0.1, 1.1)\n",
"\n",
"    # Add annotations (positions follow the parameterized key points)\n",
"    plt.annotate(\n",
"        \"Optimal Zone\",\n",
"        xy=(peak, 0.8),\n",
"        xytext=(peak, 0.9),\n",
"        ha=\"center\",\n",
"        va=\"bottom\",\n",
"        bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"yellow\", alpha=0.3),\n",
"        arrowprops=dict(arrowstyle=\"->\"),\n",
"    )\n",
"\n",
"    plt.annotate(\n",
"        \"Penalty Zone\",\n",
"        xy=(penalty_start + 2, 0.3),\n",
"        xytext=(penalty_start + 2, 0.5),\n",
"        ha=\"center\",\n",
"        va=\"bottom\",\n",
"        bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"red\", alpha=0.3),\n",
"        arrowprops=dict(arrowstyle=\"->\"),\n",
"    )\n",
"\n",
"    plt.show()\n",
"\n",
"\n",
"# Run the visualization\n",
"plot_reward_functions()\n",
"\n",
"\n",
"# Print reward values for specific retry counts\n",
"def print_reward_examples(peak=4, penalty_start=6, penalty_rate=0.2, min_reward=0.1):\n",
"    \"\"\"Print a table of reward values for example retry counts.\n",
"\n",
"    Uses the same modified-sigmoid-plus-linear-penalty formula as the\n",
"    plotting cell; the defaults match the original hard-coded values,\n",
"    so calling with no arguments prints identical output.\n",
"\n",
"    Args:\n",
"        peak: retry count where the reward is highest.\n",
"        penalty_start: retry count after which the penalty applies.\n",
"        penalty_rate: penalty subtracted per retry beyond ``penalty_start``.\n",
"        min_reward: floor for the penalized reward.\n",
"    \"\"\"\n",
"    retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
"    print(\"\\nReward values for different retry counts:\")\n",
"    print(\"Retries | Reward\")\n",
"    print(\"-\" * 20)\n",
"\n",
"    for retries in retry_examples:\n",
"        # Modified sigmoid: decays on both sides of the peak\n",
"        x = retries - peak\n",
"        reward = 1 / (1 + np.exp(-x + abs(x) / 2))\n",
"        if retries > penalty_start:\n",
"            # Linear penalty beyond penalty_start, floored at min_reward\n",
"            reward = max(min_reward, reward - penalty_rate * (retries - penalty_start))\n",
"        print(f\"{retries:7d} | {reward:.3f}\")\n",
"\n",
"\n",
"print_reward_examples()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}