{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%pip install matplotlib -q"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
def plot_reward_functions():
    """Visualize the retry-reward curves.

    Plots three curves over retry counts 0..15:
      1. a basic sigmoid centered at 4 retries,
      2. the modified sigmoid actually used for scoring, and
      3. the final reward after the late-retry penalty is applied.

    Shows the figure via ``plt.show()``; returns None.
    """
    # Generate retry counts from 0 to 15.
    retries = np.linspace(0, 15, 100)

    # 1. Basic sigmoid, centered at 4 retries.
    basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))

    # 2. Modified sigmoid: the abs(x)/2 term steepens the exponent away
    #    from the center, pulling rewards down on both sides of 4 retries.
    x = retries - 4
    modified_sigmoid = 1 / (1 + np.exp(-x + np.abs(x) / 2))

    # 3. Penalty: past 6 retries subtract 0.2 per extra retry, floored at
    #    0.1. Vectorized with np.where/np.maximum instead of the original
    #    element-wise Python loop over the numpy array.
    penalty = 0.2 * np.clip(retries - 6, 0.0, None)
    penalized_reward = np.where(
        retries > 6,
        np.maximum(0.1, modified_sigmoid - penalty),
        modified_sigmoid,
    )

    # Explicit figure/axes interface (preferred over the pyplot state
    # machine; easier to extend to multi-panel figures later).
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')
    ax.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')
    ax.plot(retries, penalized_reward, 'r-',
            label='Final Reward (with penalty)', linewidth=2)

    # Vertical markers for the two key transition points.
    ax.axvline(x=4, color='gray', linestyle=':', alpha=0.5,
               label='Peak (4 retries)')
    ax.axvline(x=6, color='gray', linestyle=':', alpha=0.5,
               label='Penalty Start (6 retries)')

    ax.grid(True, alpha=0.3)
    ax.set_xlabel('Number of Retries')
    ax.set_ylabel('Reward')
    ax.set_title('Reward Function Visualization')
    ax.legend()
    ax.set_ylim(-0.1, 1.1)

    # Annotate the optimal and penalized regions of the curve.
    ax.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),
                ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),
                arrowprops=dict(arrowstyle='->'))
    ax.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),
                ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),
                arrowprops=dict(arrowstyle='->'))

    plt.show()

# Run the visualization
plot_reward_functions()
|
|
"\n",
|
|
"# Print reward values for specific retry counts\n",
|
|
def print_reward_examples():
    """Tabulate the final reward (modified sigmoid plus the late-retry
    penalty) for a handful of representative retry counts."""
    sample_counts = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]

    print("\nReward values for different retry counts:")
    print("Retries | Reward")
    print("-" * 20)

    for count in sample_counts:
        offset = count - 4  # the reward curve is centered at 4 retries
        value = 1 / (1 + np.exp(-offset + abs(offset) / 2))
        if count > 6:
            # 0.2 penalty per retry beyond 6, floored at 0.1
            value = max(0.1, value - 0.2 * (count - 6))
        print(f"{count:7d} | {value:.3f}")

print_reward_examples()
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|