torch.hub
===================================
PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models
-----------------

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple ``hubconf.py`` file.

``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a Python function
(example: a pre-trained model you want to publish).

::

    def entrypoint_name(*args, **kwargs):
        # args & kwargs are optional, for models which take positional/keyword arguments.
        ...

How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The following code snippet specifies an entrypoint for the ``resnet18`` model, written as an
expanded version of the implementation in ``pytorch/vision/hubconf.py``.
In most cases, importing the right function in ``hubconf.py`` is sufficient. Here we
use the expanded version only as an example to show how it works.
You can see the full script in the
`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_.

::

    dependencies = ['torch']
    from torchvision.models.resnet import resnet18 as _resnet18

    # resnet18 is the name of entrypoint
    def resnet18(pretrained=False, **kwargs):
        """ # This docstring shows up in hub.help()
        Resnet18 model
        pretrained (bool): kwargs, load pretrained weights into the model
        """
        # Call the model, load pretrained weights
        model = _resnet18(pretrained=pretrained, **kwargs)
        return model


- The ``dependencies`` variable is a **list** of package names required to **load** the model. Note this might
  be slightly different from the dependencies required for training a model.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- The docstring of the function works as a help message. It explains what the model does and what
  the allowed positional/keyword arguments are.
  It's highly recommended to add a few examples here.
- An entrypoint function can return either a model (``nn.Module``) or auxiliary tools that make the user workflow smoother, e.g. tokenizers.
- Callables prefixed with an underscore are considered helper functions and won't show up in :func:`torch.hub.list()`.
- Pretrained weights can either be stored locally in the GitHub repo, or be loadable by
  :func:`torch.hub.load_state_dict_from_url()`. If the weights are smaller than 2GB, it's recommended to attach them to a `project release <https://help.github.com/en/articles/distributing-large-binaries>`_
  and use the URL from the release.
  In the example above, ``torchvision.models.resnet.resnet18`` handles ``pretrained``; alternatively, you can put the following logic in the entrypoint definition.

::

    if pretrained:
        # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
        dirname = os.path.dirname(__file__)
        checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
        state_dict = torch.load(checkpoint)
        model.load_state_dict(state_dict)

        # For checkpoint saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))


Important Notice
^^^^^^^^^^^^^^^^

- The published models should be at least in a branch/tag. It can't be a random commit.


Loading models from Hub
-----------------------

PyTorch Hub provides convenient APIs to explore all available models in the hub
through :func:`torch.hub.list()`, show docstrings and examples through
:func:`torch.hub.help()`, and load the pre-trained models using
:func:`torch.hub.load()`.


.. automodule:: torch.hub

.. autofunction:: list

.. autofunction:: help

.. autofunction:: load

.. autofunction:: download_url_to_file

.. autofunction:: load_state_dict_from_url

Running a loaded model:
^^^^^^^^^^^^^^^^^^^^^^^

Note that ``*args`` and ``**kwargs`` in :func:`torch.hub.load()` are used to
**instantiate** a model. After you have loaded a model, how can you find out
what you can do with the model?
A suggested workflow is:

- ``dir(model)`` to see all available methods of the model.
- ``help(model.foo)`` to check what arguments ``model.foo`` takes to run.

To help users explore without referring to documentation back and forth, we strongly
recommend repo owners make function help messages clear and succinct. It's also helpful
to include a minimal working example.

Where are my downloaded models saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The locations are used in the order of:

- Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
- ``$TORCH_HOME/hub``, if environment variable ``TORCH_HOME`` is set.
- ``$XDG_CACHE_HOME/torch/hub``, if environment variable ``XDG_CACHE_HOME`` is set.
- ``~/.cache/torch/hub``

.. autofunction:: get_dir

.. autofunction:: set_dir

Caching logic
^^^^^^^^^^^^^

By default, we don't clean up files after loading them. Hub uses the cache if it already exists in the
directory returned by :func:`~torch.hub.get_dir()`.

Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete
the existing GitHub folder and downloaded weights, and start a fresh download. This is useful
when updates are published to the same branch, so that users can keep up with the latest release.


Known limitations:
^^^^^^^^^^^^^^^^^^
Torch hub works by importing the package as if it were installed. There are some side effects
introduced by importing in Python. For example, you can see new items in the Python caches
``sys.modules`` and ``sys.path_importer_cache``, which is normal Python behavior.
This also means that you may get import errors when importing different models
from different repos, if the repos have the same sub-package names (typically, a
``model`` subpackage). A workaround for these kinds of import errors is to
remove the offending sub-package from the ``sys.modules`` dict; more details can
be found in `this GitHub issue
<https://github.com/pytorch/hub/issues/243#issuecomment-942403391>`_.

A known limitation that is worth mentioning here: users **CANNOT** load two different branches of
the same repo in the **same Python process**. It's just like installing two packages with the
same name in Python, which is not good. The cache might join the party and give you surprises if you
actually try that. Of course, it's totally fine to load them in separate processes.
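
The ``sys.modules`` workaround mentioned above can be sketched roughly as follows. This is only an illustration: ``purge_modules`` is a hypothetical helper, not part of ``torch.hub``, and the sub-package name ``models`` is an assumed example of a name that two repos might share. The helper accepts any dict so it can be tried out without touching the real import state.

```python
import sys


def purge_modules(prefix, modules=None):
    """Drop `prefix` and every `prefix.*` entry from a module cache.

    `modules` defaults to sys.modules. Returns the sorted list of
    names that were removed.
    """
    if modules is None:
        modules = sys.modules
    # Collect names first: deleting while iterating a dict raises RuntimeError.
    stale = [name for name in modules
             if name == prefix or name.startswith(prefix + ".")]
    for name in stale:
        del modules[name]
    return sorted(stale)


# Try it on a stand-in dict rather than the real sys.modules.
fake_cache = {"models": object(), "models.common": object(), "models_extra": object()}
removed = purge_modules("models", fake_cache)
print(removed)             # names that were dropped
print(sorted(fake_cache))  # names that remain
```

In practice one would presumably call ``purge_modules("models")`` between two :func:`torch.hub.load()` calls whose repos both ship a ``models`` package. Objects that were already created from the first repo keep working, since they hold their own references to the old modules.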