{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Input space optimization\n", "\n", "In this toy example, we directly optimize point configurations in $\\mathbb{R}^2$. The example is motivated by the toy experiments from the following paper:\n", "\n", "**A Topology Layer for Machine Learning**  \n", "R. Brüel-Gabrielsson, B. J. Nelson, A. Dwaraknath, P. Skraba, L. J. Guibas and G. Carlsson  \n", "arXiv, 2019  \n", "[PDF](https://arxiv.org/abs/1905.12200)\n", "\n", "**Note**: In all of the following examples, we use the $l_1$ norm for the Vietoris-Rips persistent homology computation (whereas the paper mentioned above uses $l_2$)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from torchph.pershom import vr_persistence_l1\n", "\n", "device = \"cuda\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "The toy data for this example consists of 300 points sampled uniformly from $[0,1]^2$." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "np.random.seed(1234)\n", "toy_data = np.random.rand(300, 2)\n", "\n", "plt.figure()\n", "plt.plot(toy_data[:, 0], toy_data[:, 1], 'b.', markersize=3)\n", "plt.title('Toy data');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example 1\n", "\n", "**Task**: Optimize for uniform distribution of H0 lifetimes." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 1 | Loss: 15.35\n", "Iteration: 20 | Loss: 5.85\n", "Iteration: 40 | Loss: 3.74\n", "Iteration: 60 | Loss: 2.66\n", "Iteration: 80 | Loss: 2.00\n", "Iteration: 100 | Loss: 1.69\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO2dbaxe1XXnf4ub0qRqXpq4maTYtwbVkYpQm4Q7ULdVcURSGTSCfqCJoaRJh9RqZjJfokqDxUAjUA2dUduhCtPkyoNIIk1gUvXFSp2ilsZKFdutL2ogASlgyAsXokJSwpeoTUhXP5znto8fznOfl7P23mufs37S1b2Pn+Oz19ln7f9ee+199hFVJQiCIOg/55Q2IAiCIMhDCH4QBMFACMEPgiAYCCH4QRAEAyEEPwiCYCC8rLQB09ixY4fu3r27tBlBEARV8eCDD35TVX+07Tu3gr979242NjZKmxEEQVAVIvK1ad9FSicIgmAghOAHQRAMhBD8IAiCgRCCHwRBMBBC8IMgCAZCCH4QBMFACMEPAiNOnoTbb29+B4FH3K7DD4KaOHkSLr8cvvtdOPdceOAB2Lu3tFVBcDYmEb6I3C0iz4rIl6Z8/ysi8vDo54SI/LRFuUHghePHG7H//veb38ePl7MlRhrBNKwi/HuADwMfn/L9V4DLVPV5EbkCWAcuNSo7CIqzb18T2W9F+Pv2lbEj1Ujj5MmmE9u3L0YuNWMi+Kr6ORHZvc33J8Y+ngJ2WpQbBF7Yu7cR19Ki2DbS6GpLpKv6Q4kc/g3AZ9q+EJGDwEGA1dXVnDYFA8Yqet27t7wQphhppOhEgjJkFXwReRuN4P982/equk6T7mFtbS1ethskp2/Ra4qRhpd0VdCdbIIvIj8FHAGuUNVv5So3CLajj9Gr9UjDS7oq6E4WwReRVeCPgXer6mM5ygyCefASvXqfFM2RrvJeB33ARPBF5JPAPmCHiGwCvwX8AICqfgS4BXgd8H9EBOBFVV2zKDsIuuAhei2VVvIksH1LrXnFapXOtTO+fx/wPouygsCaRaNXa6EskVbyJrB9TK15JJ60DYrgKbpchBRCWSKt5E1gvaTW+k4IfpAdb9HlIqQQyhJpJW8C6yG1NgRC8IPseIsuFyGVUC47KbrsSMmjwHp4jqGNWkejbYTgB9mxEM1Sj
dCTUHYdKXkVWE/UPBptIwQ/yE5X0SzdCL0IpfeRUopOOXdH772OFyUEPygSLXcRzb41wmXxlocfJ0WnXKKj91zHyxCCP3BKR8vL0LdGuCye0kuTpOiUS3T0nut4GULwB06N0XJtjTDlCMpLemmSFJ1yqY7eax0vQwh+YrzP8NcaLdfSCGscQVn4bIpOubaO3iMh+AmpobFbNSLvHdsyWFxTbSMoS59N0SnX0tF7JQQ/IbU09q6NyLpj89B5WF1TbSOoWnzWghJ+Vtq3Q/ATUltjXxZLkfAyKrK6ptrSEEPx2RJ+5sG3Q/ATUltjXxZLkbDuPJate8trqikNUcpnh7C+3sPoKQQ/MTU19mWxFAkrobV4CtVDZ13bMxLLMJT19R5GTyH4PaRmkbASWotoqnRn7SEFkIOhrK/3EESE4PeMPoiEhdB6iKa64iEFkIMhra8vHUSE4C9B6Zn27bAWCc/Xuh0eoqmupBBCj/fT+l55vEY3qKrLn4svvlg9cuKE6iteobqy0vw+caL7+Q4f7n6eFPZZX+sQSHE/rc6X635a18GiZef22ZLX2wawoVN0NSL8BfG+BNEyWhpKSsGKVPfTqs5z3M/SKcXcPlv6ehflHIuTiMjdIvKsiHxpyvciIn8gImdE5GEReatFuSXYGmavrKRZgmjB3r1w6JBdHtziWqFpHLff3vzuI6nupxXW97ON0nWQ4xrHKX29i2IV4d8DfBj4+JTvrwD2jH4uBf5w9Ls6PC5BTIXltdYWCS3DkO7nNErXQe65m9LXuyjSpHwMTiSyG/i0ql7U8t1HgeOq+snR5y8D+1T1G9POt7a2phsbGya2eWYoE0y33w4339xEQisrcNttzSjEExb3Yij3czuGVgferldEHlTVtbbvcuXwzwOeGvu8Ofq3swRfRA4CBwFWV1czmVaW0su0cuE9ErIagfTxfi4qaH2sg+2o6XpzCb60/NtLhhaqug6sQxPhpzYqyIf3ZZJeJqi9RYtDSMVN4u0eWJJL8DeBXWOfdwLPZCq7F/TBCVNGQl3rx8MIxKO4eukIc+HxHliSS/CPAh8QkXtpJmtf2C5/H5yNpRP2oeOYxKJ+PIxAPD4056EjzEnfOzgTwReRTwL7gB0isgn8FvADAKr6EeAYcCVwBvgO8GsW5Q4FKyfsa/RiVT+lc7GW4mo5J5GiI/QaeJTq4HLVh4ngq+q1M75X4L9alDVErJywr9FLX6JQrw/NWXeEngOPEiO9nPURT9pWgJUTWgujdVSy7Pk8pGOssBJXz52g98Aj90gvZ32E4FeChRN6fpDKYv96rxPCJfDcCQ5lU7h5ydk5h+APDCthtI5KvEZ9ntMPsyg9JzGNFLtj1nqPIG/nHIIfLIV1VOI1BeFx5UwfqG1TuNTk6pxD8IOlsI5KvKYgPK6cCc7Ga7DgkRD8YGmsoxKPKQivK2fAbrRQ+6ij6z1Kdf0e6zUEPxEeb3awHB5XzliNFvoy6lj2HqW6fq/1arIffnA2Wzf75pub3yn2f0+1t3zf9qz3dD1bkehtt3UXgLbRQsnzeKrnRbC6/lzn7UpE+AlIPYlUS1RSepTjMcryNlqwOE+Oek7lS6ny/17nFULwE5D6ZqfqUCzP60Fs+7B6YxpWcwsW56k1wIF0iwW8LkIIwU9A6ptdQ1TiQWy9RllWWI0Wup6n1gBni1SLBTwuQgjBT0TKm11DVOJhG4dcUZaXLSZKUWuAM0TMXnFozVBecdhnLJcNlk4PTcPbFhN9pbZOsCQeXnEYTDAEB/a6jYMlQ9liojTe0iO1tt8Q/AJEFLcYnof0XreYqFWQaiB1+01570LwC9CXKC6XqHhd8QA+t5jwHFCU6Iisy0zZflPfuxD8MXI5Y+6INcV15RYVb0P6cbxtMZFCkCx8qERHlKLMlO03dTAYgj8ipzPmjFhTXVdfRinTSNVJ1hhQWPlQCZ9JUWbK9ps6GAzBH5HbGXNFrKmuy2te3WskWnNAYeVDJXwmV
Zkp1+6nDAatXmK+H7gTWAGOqOodE9+vAh8DXjM65kZVPWZRthVeBawrKR3eItdsvX7dayRac0Bh5UMl5mI8z/9MI2kwqKqdfmgE/AngAuBc4CHgwolj1oH3j/6+EPjqrPNefPHFmpsTJ1QPH25+e6OLbR6v68QJ1Ve8QnVlpfltYdvhw835oPl9+LAf21KcMycefShoB9jQKbpqEeFfApxR1ScBRORe4Grg0fF+BXjV6O9XA88YlGvOIj1rztUGnt/3uiwpIl7PkajVOUstt/ToQ7UsPXVl57SeYN4f4BqaNM7W53cDH5445o3AF4FN4Hng4innOghsABurq6vJe8JlyR2tWUWu85IjmktVh32ORGsaJaS+D6nrwsr+EveMxBG+tPUjE5+vBe5R1d8Vkb3AJ0TkIlX9l4nOZ50m/cPa2prPPR/In4/NOb+Qa3Ix5X5AxaOoRNSyMiqHD9WyFt7bPbN4AcomsGvs805emrK5Afj/AKp6Eng5sMOg7CJsCfDKSp4J3i1xtHhxxizaHDQVe/fCoUM+RcsjqfzO+uUlOXwoZRu0tD+3VszCIsI/DewRkfOBp4EDwHUTx3wduBy4R0R+kkbwnzMouwilVhvkKKevq5W2qHl9fQq/q+3BpC1qWQvvbpXQtFzPIj/AlcBjNKt1bhr9263AVfrvK3M+T7OC5wvAL846Z4lVOkGDlzy4tR2x+ualpJof8uJDy1Kz/STO4aPNmvpjE/92y9jfjwI/Z1FWzeSKBLuW4yEPniLy7MP6ems8P5hkub32oufx0AZSEE/aZiLXZGipjbNq2KAqhbjVngJzl3IYYeXHnjeSK0EIfiZyRYIlIs5a8sCe19dbs0gH7DGatfLjXO3B1Vr7bQjBz0SuSLBExJmiUdW0bLPrOb1uMVESKz/O0R5qqu8Q/EzkigRLRJye88CW1LLNdOo16jWtSMrRHmqaxwnBz0guAcstlF7TGpbUtM10qg641ncgpG4PNc3jhOA7pJbVPOOUjsZT11mqKK6WuQqoK5LNSU0BTwj+FEpNwvR9Nc9W2bXlrD1vMz3tvLXUQRe8TJaWDnjmJQS/hZJi2OfVPFBfznqLlFFcLWJhUQeWAm3tSyU7j1xlh+C3UHLo2ufVPOA/Z71dw1tGmL1EoLOY184unZO1QFv6UukRb66yQ/BbKDl07fNqHvCds04RMdawXC+XndadvaUvlQzycpYdgt+CpYAsc46+rubZKtNrztq64aVoyClGDLkEx7qzt/SlkkFe1rKnbbJT+qf2zdNq31RriFjfM+/nS33eaWV53ZSspG2WZZN687TgpfRxCVuJfHTOMq1HH9bnS+VTOdN7nieorWzzvFlbCH4iPC5h60KJfHSJMq0bnuX5UvpUVztTLLWtYbJ7Eu/zNiH4iSg1KZqqoZQYsVivwqhRQMbx+oDPUCe72/A+sg/BT0ju4WvKhlJixGJVZs0CMsmyPpWyw/M+2Z2zs/c+sg/BXwDvUWLK6KJEdGlVZs0CYkHqDs9a5Kyfq8i9/4/HUdgWIfhzUkOUmDq6KLWM09uDWSn8oKYIfBLPk90lUiyeJ6ZD8Ockh+NYvJrQc3RRCu8CUlsE3obXyW7vKZbcmAi+iOwH7gRWgCOqekfLMe8EPgQo8JCqXmdRdi5SO45Vo/cUXXhKfXgWkNoi8JoY8rW30VnwRWQFuAt4B7AJnBaRo9q8uHzrmD3AIeDnVPV5EXl913Jzk9pxSgw9UwpyDSmwZUjhB54jcE+d9rJ4CoJKYxHhXwKcUdUnAUTkXuBq4NGxY34duEtVnwdQ1WcNys1OSsfJPfRMLcipOzDrXRcXOVeK9IXHKDSlj/ShI6kRC8E/D3hq7PMmcOnEMW8CEJHP06R9PqSqfzF5IhE5CBwEWF1dXdqgGp0pd6NPLcgpOzBLIfIyElmmE6n1pS9e6nw7atSQebAQfGn5N20pZw+wD9gJ/I2IXKSq3z7rP6muA+sAa2trk+eYixqcaRo5h
545VvSk6sAshShlx1d7yiyVj3h/OKlmDZmFheBvArvGPu8Enmk55pSqfg/4ioh8maYDOG1Q/ll4dyYv5BhRdOnAthNLSyFKJWq1p8wgnY94XznTZw2xEPzTwB4ROR94GjgATK7A+VPgWuAeEdlBk+J50qDsl+DdmTyRSpC7MkssLYUolajVnDIbJ8Wo06LOU/pfrzVk2jaai/wAVwKPAU8AN43+7VbgqtHfAvwezUTuF4EDs87ZZXtkz1uw9oHU2+kePtycG5rfhw/bnj8HObYcHqqfR91uD6m3R1bVY8CxiX+7ZexvBT44+klO6WVYqSd8Sk8o9SV6TUnNKTPv5Epn1VYv8xBP2hqTOnfrYUKp5gnfnHgVjdqXW/YhICjFoAW/xtfFeZhQ8h69blHzHu017q2TKxjpS0BQgsEKfirnTB19eIluvEavW9S8R3ute+vkDEaG/ORwFwYr+KmcM3X0UXt0k6vBWd/fnGKWuqyhLrf0kA4tzWAF39Pr4ko/2p+LnA3O8x7tHspaxIfm9c+uHUmtTw7XxGAF30uk7C3qqDF33Ib1/c3pL13KSjFvsYh/dkm11PrkcE0MVvDBR6TsKeqoNXc8Dev7m8Jfpgn0MmWluH+5/LPmJ4drYtCCvygpol9PUUetueNasRboFPcvl3/W9uRwrZO/Ifhzkir6tRbBLo7oLXfcRoqGZnXORc9jLdAp7l+uTrqmYMBbGnYhpj2CW/qny9YKKajhcX+LR867PlKe8pH0FI/UW51zmfOkup5atwSwIMf1e9cCUm+tMAQ8pV6mYRExdn1cv7YdIq3Oucx5FolqF1kpU020aUyuyLsGLZhGCP6c1DDkLO2IqecAUlyf1TmnnWeWUM8j0CmErNYc9HbkmmCuQQumEYK/AKnX13eltCPWuMeO1TnbzmMl1NZClvup4Vz+mDPgqXUkFYKfiFITOyUdMUeHk+L6rM45eR4robYWslyRcO424P3BLw+E4CfC0/r6nHgfBeXESqitO9JckXCJNuD5wS8PhOAnonQ+vQZSNTIvSzcthdpyZJMr9WfVBnIEBUMJ0AYn+LkiSstG5SEKrmUr6VQTnMuec95J2a5163E/Jos2ECtvbBmU4JfIKXoUMC82pGhknpdutmFRtx58ZBpd20CtK288BGltnFPagJy0OY93PNicyoatRnbbbfadyMqK/dJNy3NuYVG3OX3k5Em4/fbmdw5S1v0ke/fCoUN2QdrNNze/c9XVPJhE+CKyH7gTWAGOqOodU467BvgU8B9VdcOi7EWocdjmIQ/qaSvpec7ndelmGxZ1m8uvS4wkSi81XgbX8wHTHsGd94dG5J8ALgDOBR4CLmw57pXA54BTwNqs86baWsHi0evcj69bbHdQesuFIbFoXdXik963FPBCii0zFoFttlawEPy9wP1jnw8Bh1qO+9/AfwKOlxT8rpS+mctQqqEOsZOo0T/mpc/XZk1J399O8C1SOucBT4193gQuHT9ARN4C7FLVT4vIb047kYgcBA4CrK6uGphmj+vh2hRKpLI8TyROYjnBlmrlkYeURo3plVJ15/VJXAvBl5Z/03/7UuQc4PeB9846kaquA+sAa2trOuPwInQVzxIOWKKhpuoYS7/RaRbWnaulfRZ151XI2qgp6MiFheBvArvGPu8Enhn7/ErgIuC4iAC8ATgqIldpgYnbrnQRz5IOmLuhphhV1PBGJ+vO1co+j+JXyztsvYywLLAQ/NPAHhE5H3gaOABct/Wlqr4A7Nj6LCLHgd+sUey3WFY8a0gHWTl3ilFFivpL9dIQq/tqZZ8336vlHbYeO8oudBZ8VX1RRD4A3E+zYuduVX1ERG6lmTw42rWMvpAql275xiZL57YeVXh7o1OOyM+q4/S2JDlHB2RRd946ys5Mm80t/eN1lU5XrGfvLVdO1LDszsvKnxpXrHipuy1baqi/Wuwch3jjlR+so17LCMRbFNiGl0nDGvPDXuoO6lnxU4ud8xKCX
zmWIt0n504tpJEf7o6nDmg7arFzHkLwK8dapEs5t6VA51jKGPnh/PRptUwpQvB7QO0RiHWkm2spY9d6txolpBBBb+I69NGQFSH4C5KrIaR4wMirMFhHurUsZew6Skj5Ahlv4hqjIRtC8BcgV0OwLse7MFhPFte0lLHLKCGVCHoUVw+7xvaBEPwFyNUQrMvxLgwpJost0lzL2pVLVFJ1SB5TTRY+4nHkkptBCX5XJ8y1bNG6HM/CsIXXeYhF7copKqlWVXlNNXX1EY8jl9wMRvAtnDDXssUUK288CkMfyS0qqTpKj6mmrtTwnElqBiP4lumHHM5rXY5HYVgWz3nYEBW/dRAByoAEv6QTehao2vCeh7UUlVr9ZtE6GOrTxiUYjOCX6t29C9Q0UjRCj0s4U2AhKrX6zRbz1kHt11kbgxF8KNO75xAo7y8FsTyn13SBNZZ+43mkUEMH3icGJfglSC1QNbwUxPKcQ8nDWq47t36mw7Luh9KBeyEEf4wUkVBqgUohzikaYaklnLU8sTyJld9YjxSsg4uhdOBumLZvcumf3Pvh17jvtWo6u1PsnZ57P3bruqnRR4bwvgRP+/x7gNgPfza15hJTrrG3vv7ccyi1PLGcEkv/8Jh+iUnfxQjBH+HRmedl6EvNpuH5ieUalyJ6TL9Yd8KeJ7gtCMEf4dGZg254fWK55qjUW3Bh3QnXel/mxUTwRWQ/cCfNS8yPqOodE99/EHgf8CLwHPCfVfVrFmVb0tWZa4gOPG+TnAKPTyzXmBpahppe8g4DuS/Tkvvz/tCI/BPABcC5wEPAhRPHvA34odHf7wfum3Xe2l5iXsOEXsoJ3lTX3scJuRT15a2eamgPk3S12cs9IPGk7SXAGVV9EkBE7gWuBh4d61Q+O3b8KeB6g3JdkSI6sI6QUkUwqc7b1yG2darJYz3VGC13uS8e70EbFoJ/HvDU2OdN4NJtjr8B+EzbFyJyEDgIsLq6amBaPqwnCFM4UA3bJI+TWzRqnEgF21c6Wl1/rYsglr0vtXRwFoIvLf+mrQeKXA+sAZe1fa+q68A6wNraWus5vGIdtaVwoJRLOFOcN6do1BKhtWH1whLL6+/qE17nhKZRSwdnIfibwK6xzzuBZyYPEpG3AzcBl6nqPxuU6w7LqC2VA6VaZZFq3X6ulVO1RGhtWNRTqgBj2ZFGbZ1vLav8LAT/NLBHRM4HngYOANeNHyAibwE+CuxX1WcNyuw9tThQanItA6wlQptG13rydP0eU1Tz4G3JahudBV9VXxSRDwD306zYuVtVHxGRW2lmi48C/wv4YeBTIgLwdVW9qmvZfacGB1qGlA1x2XNbrrGvsZP2FGB4TFH1BZN1+Kp6DDg28W+3jP39dotygtl4F5yUDbHruS2ew6hZZLwEGF5TVH0gnrTtEanF1PvDLaUbuUX53jvsXHhJUfXtfoTgd8CbM9SwHj5lrrh0Hrpr+bWPEDxhMUro4/0IwV8Sj85Qw3r4lLni0nnoruVbd9jeAhKo61mH0iPGFITgL4mlM1g1glrWw6fMFZfOQ3cpv+8bgXm0aTtKjxhTEIK/JF5fQVf7evjt8PpidSss69ljdOrRpu3w4veWhOAviZUz1NIISkfOnl+sbolVPXuMTj3aNIvSfm9NCH4HLJyhxkZQghQdYy2d7TJ4jE492jQ0QvAL07dGkCpF4v3F6h7xGJ16tGlISLN9sj/W1tZ0Y2OjtBnBAqROkfQ9hx/ko8/3XUQeVNW1tu8iwp9Cnx0iFalTJKkmpEuvrgry4nHuJhch+C0M2SG64ClFklqMS/pIjmvrc0fW57mbWYTgt+DZITw3Ri/zETnEuJSP5Eib9T3Y8RSY5CYEvwWvDpGrMXbpVDxMyuUQ41I+kvraSnZkOZ/A9RCYlCAEv4WUDtHFsXM0xj5EeDnEuIuPdPGB1NdWoiMr4XMeApMShOBPIYVDdHXsHI0xd4SXIrLLFcEt4yMWWzinvLYS0a/nFGrfCMHPS
FfHztEY+/Ie2WU77NSpBQtxW+Talrme3NGv1xRqHwnBz4iFY6dujDkjPG+RXY7UQl86VEuGnFPPTQh+Rko79rzRXq4Iz1tkl6MD6luHarnTawh9ekLwM1PKsT1Ge6U7wElydUB96VD7/CxCXzERfBHZD9xJ8xLzI6p6x8T3Pwh8HLgY+BbwLlX9qkXZwXx4S59s4SnX7q0D6krq6+nrswh9prPgi8gKcBfwDmATOC0iR1X10bHDbgCeV9WfEJEDwO8A7+patie8Rxw5otdcdeBxstcrKa+nr88i9BmLCP8S4IyqPgkgIvcCVwPjgn818KHR338EfFhERL3u3LYgNUQcqaO9nHUwxAbvMaAoNSLyNvdTExaCfx7w1NjnTeDSaceo6osi8gLwOuCb4weJyEHgIMDq6qqBaXmoRYBSRns562BoDd5zQFFiRNS31FtOLARfWv5tMnKf5xhUdR1Yh2Z75O6m5WFoAtRGzjrw1OBzRN45O1OPI4k2+pZ6y4WF4G8Cu8Y+7wSemXLMpoi8DHg18I8GZbvAkwCVIncdeGjwuSLvXJ2p55FEYIOF4J8G9ojI+cDTwAHguoljjgLvAU4C1wB/3Zf8/RYlBchLVOZBhLfDup5yRd65OtNaUpPB8nQW/FFO/gPA/TTLMu9W1UdE5FZgQ1WPAv8X+ISInKGJ7A90LTdo6GNUlurNVtb1lDuNlfq+9jE16SUY8oLJOnxVPQYcm/i3W8b+/ifgly3KCs6mb1FZqg4sRT0tG3l7FaEcI4mc197HYKgr8aRt5fQtKkvVgaWqp0Ujb+/v/U05ksgtwH0LhiwIwa+cvk0YpxRmD/WUUoS8R7S5BbhvwZAFIfg9wPtk6SKkFGYP9ZRShLxHtLkF2Esn7wnxulhmbW1NNzY2SpsxSLzmmPtCqvqt4RWYnn3Ls22LICIPqupa63ch+L7J7YQpRCP1NfSloVqQo649p42WpU/XtZ3gR0rHMSWc0DotkGOSMtX5a+xIUqetvKeNlqWv1zXJOaUNCKbT5oSp2cqzrqzY5FlTX0Oq8291JDff3Pw+edLmvLVj7R9e6Ot1TRIRvmNKrDKwnuhKfQ2pzj+UiG9R+joR2tfrmiRy+M6pMa0wSY05fC853T7c/yAvMWnrgGi49VH6nnnpdIK6iEnbwnhpuKUFrDZKr9uPtFJgTQh+Bjw0XC+dTptdOfdWqanDiydFA2tC8DPgoeF66HQmydkJee3wtsNiIrG2Ti5ISwh+BjysAPDQ6UySsxPy2OHNQ5e0UolOLjoY34TgZ6J0PjhFp9O1cefshCzLqkXUcndyNY6ihkYI/oCw7HQsGnfOkY9VWTWJWu5RXa2jqCERgh8shVXjzjnysSirJlHLnUr0mDYMziYEP1iKoTbu2q47d4daeq4q2J548CpYmlpy2dYM9bqDOkj2pK2IvBa4D9gNfBV4p6o+P3HMm4E/BF4FfB/4bVW9b9a5Q/CDPhKdRZCalE/a3gg8oKp3iMiNo8//feKY7wC/qqqPi8iPAQ+KyP2q+u2OZQdBVdQ04Rv0k67bI18NfGz098eAX5o8QFUfU9XHR38/AzwL/GjHcoOgOkpsdx0E43QV/P+gqt8AGP1+/XYHi8glwLnAE1O+PygiGyKy8dxzz3U0LQh8MZQ91wO/zEzpiMhfAW9o+eqmRQoSkTcCnwDeo6r/0naMqq4D69Dk8Bc5fxB4J1axBKWZKfiq+vZp34nIP4jIG1X1GyNBf3bKca8C/hz4H6p6amlrg6BySj9xHQybrimdo8B7Rn+/B/izyQNE5FzgT4CPq+qnOpYXBEEQLElXwb8DeIeIPA68Y/QZEVkTkSOjY94J/ALwXhH5wujnzR3LDYIgCBYkHrwKgiDoEdutw+8a4QdBEASVEIIfBEEwEELwgyAIBoLbHL6IPPkWZAgAAAScSURBVAd8LXExO4BvJi6jKzXYCHXYGTbaUYOdNdgI9nb+uKq27mbgVvBzICIb0yY3vFCDj
VCHnWGjHTXYWYONkNfOSOkEQRAMhBD8IAiCgTB0wV8vbcAc1GAj1GFn2GhHDXbWYCNktHPQOfwgCIIhMfQIPwiCYDCE4AdBEAyEQQm+iLxWRP5SRB4f/f6RlmPeLCInReQREXlYRN6Vybb9IvJlETkzel3k5Pc/KCL3jb7/WxHZncOuBW38oIg8Oqq3B0Tkx3PbOI+dY8ddIyIqItmX7s1jo4i8c1Sfj4jI/8tt48iGWfd8VUQ+KyJ/P7rvVxaw8W4ReVZEvjTlexGRPxhdw8Mi8laHNv7KyLaHReSEiPx0EkNUdTA/wP8Ebhz9fSPwOy3HvAnYM/r7x4BvAK9JbNcKzVvALqB5I9hDwIUTx/wX4COjvw8A92Wuu3lsfBvwQ6O/35/bxnntHB33SuBzwClgzZuNwB7g74EfGX1+vce6pJlwfP/o7wuBrxaw8xeAtwJfmvL9lcBnAAF+Bvhbhzb+7Ni9viKVjYOK8PH7Dt5LgDOq+qSqfhe4d2TrOOO2/xFwuYhIYrsWslFVP6uq3xl9PAXszGjfFvPUJcBtNAHAP+U0bsQ8Nv46cJeqPg+gqq0vF0rMPHYq8KrR368GnsloX2OA6ueAf9zmkKtp3seh2ryA6TWjFzZlY5aNqnpi616TsO0MTfBN38FryHnAU2OfN0f/1nqMqr4IvAC8LrFdreWPaLNxnBtooqrczLRTRN4C7FLVT+c0bIx56vJNwJtE5PMickpE9mez7t+Zx84PAdeLyCZwDPhveUxbiEV9tzTJ2s7MVxzWRs538BrSFqlPrped55iUzF2+iFwPrAGXJbWonW3tFJFzgN8H3pvLoBbmqcuX0aR19tFEe38jIhep6rcT2zbOPHZeC9yjqr8rInuBT4zsTN1mFqF025kbEXkbjeD/fIrz907wtc538G4Cu8Y+7+SlQ+OtYzZF5GU0w+fthrHWzGMjIvJ2ms71MlX950y2jTPLzlcCFwHHRxmxNwBHReQqVc31xp157/cpVf0e8BUR+TJNB3A6j4n/ZsMsO28A9gOo6kkReTnNZmAlUlDTmMt3SyMiPwUcAa5Q1W+lKGNoKR2v7+A9DewRkfNH5R8Y2TrOuO3XAH+toxkeLzaOUiUfBa4qlHOGGXaq6guqukNVd6vqbpp8aU6xn2njiD+lmQRHRHbQpHiezGgjzGfn14HLAUTkJ4GXA89ltXI2R4FfHa3W+Rngha3UrhdEZBX4Y+DdqvpYsoJyz1aX/KHJeT8APD76/drRv68BR0Z/Xw98D/jC2M+bM9h2JfAYzXzBTaN/u5VGjKBpSJ8CzgB/B1xQoP5m2fhXwD+M1dvRQvd5Wzsnjj1O5lU6c9alAL8HPAp8ETjgsS5pVuZ8nmYFzxeAXyxg4ydpVtN9jyaavwH4DeA3xuryrtE1fLHQ/Z5l4xHg+bG2s5HCjthaIQiCYCAMLaUTBEEwWELwgyAIBkIIfhAEwUAIwQ+CIBgIIfhBEAQDIQQ/CIJgIITgB0EQDIR/BQIyzo8tzj3SAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X = torch.tensor(\n", " toy_data, \n", " device=device, \n", " requires_grad=True)\n", "\n", "opt = torch.optim.Adam([X], lr=0.01)\n", "\n", "for i in range(1,100+1):\n", " pers = vr_persistence_l1(X, 1, 0)\n", " h_0 = pers[0][0]\n", " \n", " lt = h_0[:, 1] # H0 lifetimes\n", " loss = (lt - 0.1).abs().sum()\n", " \n", " if i % 20 == 0 or i == 1:\n", " print('Iteration: {:3d} | Loss: {:.2f}'.format(i, loss.item()))\n", " \n", " opt.zero_grad()\n", " loss.backward()\n", " opt.step() \n", " \n", "X = X.cpu().detach().numpy()\n", "plt.figure()\n", "plt.plot(X[:, 0], X[:, 1], 'b.');\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example 2\n", "\n", "**Task**: Minimize (non-essential) H0 lifetimes (i.e., a slightly modified as in *Brüel-Gabrielsson et al., arXiv 2019*, Fig. 1 top-right)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 1 | Loss: 14.74\n", "Iteration: 20 | Loss: 7.25\n", "Iteration: 40 | Loss: 5.65\n", "Iteration: 60 | Loss: 4.87\n", "Iteration: 80 | Loss: 4.43\n", "Iteration: 100 | Loss: 4.10\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X = torch.tensor(\n", " toy_data, \n", " device=device, \n", " requires_grad=True)\n", "\n", "opt = torch.optim.Adam([X], lr=0.01)\n", "\n", "for i in range(1,100+1):\n", " pers = vr_persistence_l1(X, 1, 0)\n", " h_0 = pers[0][0]\n", " \n", " lt = h_0[:, 1] # non-essential H0 lifetimes\n", " loss = lt.sum()\n", " \n", " if i % 20 == 0 or i == 1:\n", " print('Iteration: {:3d} | Loss: {:.2f}'.format(i, loss.item()))\n", " \n", " opt.zero_grad()\n", " loss.backward()\n", " opt.step() \n", " \n", "X = X.cpu().detach().numpy()\n", "plt.figure()\n", "plt.plot(X[:, 0], X[:, 1], 'b.');\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example 3\n", "\n", "**Task**: Increase (non-essential) H0 lifetimes (i.e., a slightly modified version as in *Brüel-Gabrielsson et al., arXiv 2019*, Fig. 1 top-left)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 1 | Loss: -14.74\n", "Iteration: 20 | Loss: -26.88\n", "Iteration: 40 | Loss: -33.22\n", "Iteration: 60 | Loss: -38.43\n", "Iteration: 80 | Loss: -43.28\n", "Iteration: 100 | Loss: -47.87\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X = torch.tensor(\n", " toy_data, \n", " device=device, \n", " requires_grad=True)\n", "\n", "opt = torch.optim.Adam([X], lr=0.01)\n", "\n", "for i in range(1,100+1):\n", " pers = vr_persistence_l1(X, 1, 0)\n", " h_0 = pers[0][0]\n", " \n", " lt = -h_0[:, 1] # non-essential H0 lifetimes\n", " loss = lt.sum()\n", " \n", " if i % 20 == 0 or i == 1:\n", " print('Iteration: {:3d} | Loss: {:.2f}'.format(i, loss.item()))\n", " \n", " opt.zero_grad()\n", " loss.backward()\n", " opt.step() \n", " \n", "X = X.cpu().detach().numpy()\n", "plt.figure()\n", "plt.plot(X[:, 0], X[:, 1], 'b.');\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }