torch.hub
===================================
PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models
-----------------

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple ``hubconf.py`` file.

``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a Python function
(for example, a pre-trained model you want to publish).

::

    def entrypoint_name(*args, **kwargs):
        # args & kwargs are optional, for models which take positional/keyword arguments.
        ...

How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here is a code snippet that specifies an entrypoint for the ``resnet18`` model if we expand
the implementation in ``pytorch/vision/hubconf.py``.
In most cases, importing the right function in ``hubconf.py`` is sufficient. Here we
use the expanded version only as an example to show how it works.
You can see the full script in the
`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_.

::

    dependencies = ['torch']
    from torchvision.models.resnet import resnet18 as _resnet18

    # resnet18 is the name of entrypoint
    def resnet18(pretrained=False, **kwargs):
        """ # This docstring shows up in hub.help()
        Resnet18 model
        pretrained (bool): kwargs, load pretrained weights into the model
        """
        # Call the model, load pretrained weights
        model = _resnet18(pretrained=pretrained, **kwargs)
        return model


- ``dependencies`` variable is a **list** of package names required to **load** the model. Note this might
  be slightly different from dependencies required for training a model.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- The docstring of the function works as a help message. It explains what the model does and what
  the allowed positional/keyword arguments are. It's highly recommended to add a few examples here.
- An entrypoint function can return either a model (``nn.Module``) or auxiliary tools that make the user workflow smoother, e.g. tokenizers.
- Callables prefixed with an underscore are considered helper functions and won't show up in :func:`torch.hub.list()`.
- Pretrained weights can either be stored locally in the GitHub repo, or be loadable by
  :func:`torch.hub.load_state_dict_from_url()`. If the weights are smaller than 2GB, it's recommended to attach them to a `project release <https://help.github.com/en/articles/distributing-large-binaries>`_
  and use the URL from the release.
  In the example above, ``torchvision.models.resnet.resnet18`` handles ``pretrained``; alternatively, you can put the following logic in the entrypoint definition.

::

    if pretrained:
        # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
        dirname = os.path.dirname(__file__)
        checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
        state_dict = torch.load(checkpoint)
        model.load_state_dict(state_dict)

        # For checkpoint saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
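
The underscore convention for helper functions can be illustrated without PyTorch.
The following is a minimal sketch, assuming that an entrypoint is any public callable
defined in ``hubconf.py``. The names ``HUBCONF_SOURCE``, ``list_entrypoints``, ``_build``
and ``tiny_model`` are hypothetical, and the filtering shown approximates the rule
rather than reproducing the actual :func:`torch.hub.list()` implementation.

```python
import types

# Hypothetical hubconf.py contents, inlined as a string so the sketch is self-contained.
HUBCONF_SOURCE = '''
dependencies = []

def _build(name):
    """Helper: hidden from the listing because of the leading underscore."""
    return {"name": name}

def tiny_model(pretrained=False, **kwargs):
    """Toy entrypoint; returns a plain dict instead of an nn.Module."""
    model = _build("tiny_model")
    model["pretrained"] = pretrained
    return model
'''


def list_entrypoints(module):
    """Return the names of public callables found in ``module``."""
    return [
        name
        for name in dir(module)
        if callable(getattr(module, name)) and not name.startswith("_")
    ]


# Execute the source in a fresh module object, much like importing hubconf.py.
hubconf = types.ModuleType("hubconf")
exec(HUBCONF_SOURCE, hubconf.__dict__)

entrypoints = list_entrypoints(hubconf)      # only the public entrypoint is listed
model = hubconf.tiny_model(pretrained=True)  # keyword arguments reach the entrypoint
help_text = hubconf.tiny_model.__doc__       # the docstring doubles as the help message
```

Note that ``dependencies`` is not listed because it is not callable, and ``_build`` is
skipped because of its leading underscore.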


Important Notice
^^^^^^^^^^^^^^^^

- Published models should be on at least a branch or tag; they can't be a random commit.


Loading models from Hub
-----------------------

PyTorch Hub provides convenient APIs to explore all available models in hub
through :func:`torch.hub.list()`, show docstrings and examples through
:func:`torch.hub.help()`, and load the pre-trained models using
:func:`torch.hub.load()`.


.. automodule:: torch.hub

.. autofunction:: list

.. autofunction:: help

.. autofunction:: load

.. autofunction:: download_url_to_file

.. autofunction:: load_state_dict_from_url

Running a loaded model:
^^^^^^^^^^^^^^^^^^^^^^^

Note that ``*args`` and ``**kwargs`` in :func:`torch.hub.load()` are used to
**instantiate** a model. After you have loaded a model, how can you find out
what you can do with the model?
A suggested workflow is:

- ``dir(model)`` to see all available methods of the model.
- ``help(model.foo)`` to check what arguments ``model.foo`` takes to run.
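
The same exploration can be done programmatically on any Python object. The following
is a minimal sketch that uses a hypothetical ``ToyModel`` class as a stand-in for whatever
:func:`torch.hub.load()` returns; ``inspect.signature`` and ``inspect.getdoc`` expose the
information that ``help()`` prints.

```python
import inspect


class ToyModel:
    """Hypothetical stand-in for an object returned by torch.hub.load()."""

    def forward(self, x, scale=1.0):
        """Multiply ``x`` by ``scale``."""
        return x * scale

    def describe(self):
        """Return a short description."""
        return "toy"


model = ToyModel()

# Equivalent of scanning dir(model) for public methods.
public_methods = [
    name
    for name in dir(model)
    if not name.startswith("_") and callable(getattr(model, name))
]

# Equivalent of reading help(model.forward).
forward_signature = str(inspect.signature(model.forward))
forward_doc = inspect.getdoc(model.forward)
```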

To help users explore without referring to documentation back and forth, we strongly
recommend repo owners make function help messages clear and succinct. It's also helpful
to include a minimal working example.

Where are my downloaded models saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The locations are used in the following order:

- Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
- ``$TORCH_HOME/hub``, if environment variable ``TORCH_HOME`` is set.
- ``$XDG_CACHE_HOME/torch/hub``, if environment variable ``XDG_CACHE_HOME`` is set.
- ``~/.cache/torch/hub``
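
The lookup order above can be sketched as a small function. This is an illustration of
the documented precedence only, not the implementation of :func:`~torch.hub.get_dir()`;
the name ``resolve_hub_dir`` and its ``explicit_dir`` and ``env`` parameters are
hypothetical.

```python
import os


def resolve_hub_dir(explicit_dir=None, env=None):
    """Pick the hub directory following the documented precedence."""
    env = os.environ if env is None else env
    if explicit_dir is not None:
        # Corresponds to a prior hub.set_dir(<PATH_TO_HUB_DIR>) call.
        return explicit_dir
    if env.get("TORCH_HOME"):
        return os.path.join(env["TORCH_HOME"], "hub")
    if env.get("XDG_CACHE_HOME"):
        return os.path.join(env["XDG_CACHE_HOME"], "torch", "hub")
    # Fallback when nothing else is configured.
    return os.path.join(os.path.expanduser("~"), ".cache", "torch", "hub")


# Both variables set: TORCH_HOME takes precedence over XDG_CACHE_HOME.
both_set = resolve_hub_dir(env={"TORCH_HOME": "/data/torch", "XDG_CACHE_HOME": "/data/cache"})
xdg_only = resolve_hub_dir(env={"XDG_CACHE_HOME": "/data/cache"})
default_dir = resolve_hub_dir(env={})
```

Passing ``env`` explicitly keeps the sketch deterministic; omitting it reads the real
process environment.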

.. autofunction:: get_dir

.. autofunction:: set_dir

Caching logic
^^^^^^^^^^^^^

By default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in the
directory returned by :func:`~torch.hub.get_dir()`.

Users can force a reload by calling ``hub.load(..., force_reload=True)``. This deletes
the existing GitHub folder and downloaded weights, then starts a fresh download. This is useful
when updates are published to the same branch, so users can keep up with the latest release.
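
The cache-then-reload behavior described above can be modeled in a few lines. This is a
simplified sketch under stated assumptions, not the code path used by
:func:`torch.hub.load()`: ``fetch_repo``, ``fake_download`` and the directory name
``owner_repo_main`` are hypothetical, and the download is replaced by writing a
placeholder file.

```python
import os
import shutil
import tempfile


def fetch_repo(cache_dir, repo_name, download, force_reload=False):
    """Return the cached copy of ``repo_name``, downloading it only when needed."""
    target = os.path.join(cache_dir, repo_name)
    if force_reload and os.path.isdir(target):
        # Discard the stale copy so the next step starts from scratch.
        shutil.rmtree(target)
    if not os.path.isdir(target):
        os.makedirs(target)
        download(target)
    return target


download_calls = []


def fake_download(target):
    """Hypothetical downloader: records the call and writes a marker file."""
    download_calls.append(target)
    with open(os.path.join(target, "hubconf.py"), "w") as f:
        f.write("# placeholder\n")


with tempfile.TemporaryDirectory() as cache_dir:
    fetch_repo(cache_dir, "owner_repo_main", fake_download)
    fetch_repo(cache_dir, "owner_repo_main", fake_download)  # served from the cache
    calls_before_reload = len(download_calls)
    fetch_repo(cache_dir, "owner_repo_main", fake_download, force_reload=True)
    calls_after_reload = len(download_calls)
```

The second call does not trigger a download because the directory already exists; only
the call with ``force_reload=True`` fetches again.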


Known limitations:
^^^^^^^^^^^^^^^^^^
Torch Hub works by importing the package as if it were installed. Importing in Python
introduces some side effects. For example, you can see new items in the Python caches
``sys.modules`` and ``sys.path_importer_cache``, which is normal Python behavior.
This also means that you may get import errors when importing different models
from different repos, if the repos have the same sub-package names (typically, a
``model`` subpackage). A workaround for these kinds of import errors is to
remove the offending sub-package from the ``sys.modules`` dict; more details can
be found in `this GitHub issue
<https://github.com/pytorch/hub/issues/243#issuecomment-942403391>`_.
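
The name clash and the ``sys.modules`` workaround can be reproduced with plain Python.
The sketch below builds two throwaway directories that each contain a module with the
same name; ``make_repo``, ``import_models_from`` and the module name ``hubdemo_models``
are hypothetical. Here the clash shows up as a stale module rather than an import error,
but the underlying cause, the ``sys.modules`` cache, is the same.

```python
import importlib
import os
import sys
import tempfile

MODULE_NAME = "hubdemo_models"  # hypothetical sub-package name shared by both repos


def make_repo(root, repo_name, marker):
    """Create a directory holding a module that identifies its origin."""
    repo_dir = os.path.join(root, repo_name)
    os.makedirs(repo_dir)
    with open(os.path.join(repo_dir, MODULE_NAME + ".py"), "w") as f:
        f.write(f"NAME = {marker!r}\n")
    return repo_dir


def import_models_from(repo_dir):
    """Import the shared module name with ``repo_dir`` temporarily on sys.path."""
    sys.path.insert(0, repo_dir)
    try:
        importlib.invalidate_caches()
        return importlib.import_module(MODULE_NAME)
    finally:
        sys.path.remove(repo_dir)


with tempfile.TemporaryDirectory() as root:
    repo_a = make_repo(root, "repo_a", "from repo_a")
    repo_b = make_repo(root, "repo_b", "from repo_b")

    first = import_models_from(repo_a).NAME
    stale = import_models_from(repo_b).NAME  # cached module from repo_a is reused

    # Workaround: drop the cached entry so the next import searches sys.path again.
    sys.modules.pop(MODULE_NAME, None)
    fresh = import_models_from(repo_b).NAME

    sys.modules.pop(MODULE_NAME, None)  # leave the interpreter state clean
```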

A known limitation worth mentioning here: users **CANNOT** load two different branches of
the same repo in the **same Python process**. It's just like installing two packages with the
same name in Python, which does not work well. The cache may also interfere and produce surprising
results if you actually try that. Of course, it's totally fine to load them in separate processes.