
TorchScript: tracing vs. scripting

2022-05-23 20:03

PyTorch provides two methods to turn an nn.Module into a graph represented in TorchScript format: tracing and scripting. This article will:

  1. Compare their pros and cons, with a focus on useful tips for tracing.
  2. Try to convince you that torch.jit.trace should be preferred over torch.jit.script for deployment of non-trivial models.

The second point might be an uncommon opinion: if I Google "tracing vs scripting", the first article recommends scripting as the default. But tracing has many advantages. In fact, "tracing as default, scripting only when necessary" is the strategy with which all detection & segmentation models in Facebook/Meta products were deployed by the time I left.

Why is tracing better? TL;DR: (i) it does not damage code quality; (ii) its main limitations can be addressed by mixing in scripting.

Terminology

We start with some terminology. Some of these terms are ambiguous, so I give them my own definitions for this article.

  • Export: refers to the process that turns a model written in eager-mode Python code into a graph that describes the computation.

  • Tracing: An export method. It runs a model with certain inputs, and "traces / records" all the operations that are executed into a graph.

    torch.jit.trace is an export API that uses tracing, used like torch.jit.trace(model, input). See its tutorial and API.

  • Scripting: Another export method. It parses the Python source code of the model, and compiles the code into a graph.

    torch.jit.script is an export API that uses scripting, used like torch.jit.script(model). See its tutorial and API.

  • TorchScript: This is an overloaded term:

    • It often refers to the representation / format of the exported graph.
    • But sometimes it refers to the scripting export method.

    To avoid confusion, I'll never use "TorchScript" alone in this article. I'll use "TS-format" to refer to the format, and "scripting" to refer to the export method.

    Because this term is used with ambiguity, it may have caused the impression that "scripting" is the "official / preferred" way to create a TS-format model. But that's not necessarily true.

  • (Torch)Scriptable: A model is "scriptable" if torch.jit.script(model) succeeds, i.e. it can be exported by scripting.

  • Traceable: A model is "traceable" if torch.jit.trace(model, input) succeeds for a typical input.

  • Generalize: A traced model (the object returned by trace()) "generalizes" to other inputs if it behaves correctly when given inputs different from those used during tracing. Scripted models always generalize.

  • Dynamic control flow or data-dependent control flow: control flow where the operators to be executed depend on the input data, e.g. for a Tensor x:

    • if x[0] == 4: x += 1 is a dynamic control flow.

    • The following is NOT a dynamic control flow:

      model: nn.Sequential = ...
      for m in model:
          x = m(x)

    • The following is NOT a dynamic control flow either:

      class A(nn.Module):
          backbone: nn.Module
          head: Optional[nn.Module]

          def forward(self, x):
              x = self.backbone(x)
              if self.head is not None:
                  x = self.head(x)
              return x

The Cost of Scriptability

If anyone says "we'll make Python better by writing a compiler for it", you should immediately be alarmed and know that this is an extremely difficult goal. Python is too big and too dynamic. A compiler can only support a subset of its syntax features and builtins, at best -- the scripting compiler in PyTorch is no exception.

What subset of Python does this compiler support? A rough answer is: the compiler has good support for the most basic syntax, but medium to no support for anything more complicated (classes, builtins, dynamic types, etc.). But there is no clear answer: even the developers of the compiler usually need to run the code to see if it can be compiled or not.

Though there isn't a clear list of constraints, I can tell from my experience what impact they have had on large projects: code quality is the cost of scriptability.

Impact on Most Projects

Most projects choose to stay on the "safe side" and use only the basic syntax of Python: no/few custom structures, no inheritance, no Union, no **kwargs, no lambda, no dynamic types, etc.

This is because compiler features with "medium-level support" are not robust enough: they may work in some cases but fail in others. And because there is no clear spec of what is supported, users are unable to reason about or work around the failures. So they eventually stay only on the safe side.
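
For example (a hypothetical snippet; the exact error type and message vary across PyTorch versions), even an innocuous lambda is enough to make the compiler refuse a function:

import torch

def apply_twice(x: torch.Tensor) -> torch.Tensor:
    f = lambda t: t * 2  # lambdas are among the syntax the scripting compiler rejects
    return f(f(x))

print(apply_twice(torch.ones(2)))   # works fine in eager mode
# torch.jit.script(apply_twice)     # expected to raise a frontend error on the lambda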

The terrible consequence is that developers stop making abstractions / exploring useful language features due to concerns about scriptability.

A related hack that many projects do is to rewrite part of the code for scripting: create a separate, inference-only forward codepath that makes the compiler happy. This also makes the project harder to maintain.
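
A hypothetical sketch of this pattern (all names made up): the model keeps its rich training-time interface, plus a duplicated, compiler-friendly codepath for export that must be kept in sync by hand:

import torch
from torch import nn
from typing import Dict

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # rich interface used in training; may rely on syntax the compiler rejects
        return {"out": self.linear(batch["x"])}

    @torch.jit.export
    def forward_for_export(self, x: torch.Tensor) -> torch.Tensor:
        # separate inference-only codepath written to satisfy the compiler
        return self.linear(x)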

Impact on Detectron2

Detectron2 supports scripting, but the story was a bit different: its code quality, which we value a lot in research, did not go downhill. Instead, with some creativity and direct support from the PyTorch team (and some volunteered help from Alibaba engineers), we managed to make most models scriptable without removing any abstractions.

However, it was not an easy task: we had to add dozens of syntax fixes to the compiler, find creative workarounds, and develop some hacky patches in detectron2 that live in this file (which honestly could affect maintainability in the long term). I would not recommend that other projects aim for "scriptability without losing abstractions" unless they are also closely supported by the PyTorch team.

Recommendation

If you think "scripting seems to work for my project, so let's adopt it in full", I might advise against it for the following reasons, based on my past experience with a few projects that support scripting:

  • What "works" might be brittle (unless you limit yourself to the basic syntax):It might compile now, but one day you'll add a few more innocent features to your modeland find that the compiler refuses it.

  • Basic syntax is not enough: Even if abstractions don't appear necessary to your project at the moment, if the project is expected to grow, it will require abstractions in the future.

    Take a multi-task detector for example:

    1. There could be 10s of inputs, so it's preferable to use some classes.
    2. The same data can have different representations (e.g. different ways to represent a segmentation mask), which demands Union or more dynamic types.
    3. There are many architectural choices of a detector, which makes inheritance useful.

    Large projects definitely need abstractions to stay healthy.

  • Code quality could severely deteriorate: Ugly code starts to accumulate, because clean code sometimes just doesn't compile. Also, due to syntax limitations of the compiler, abstractions cannot easily be made to clean up the ugliness. The health of the project gradually goes downhill.

Below is a complaint from the PyTorch issue tracker. The issue itself is just one small papercut of scripting, but I've heard similar complaints many times. The status quo is: scripting forces you to write ugly code, so use it only when necessary.

Make a Model Trace and Generalize

The Cost of Traceability

What it takes to make a model traceable is very clear, and has a much smaller impact on code health.

  1. First, neither scripting nor tracing can work if the model is not even a proper single-device, connected graph representable in TS-format. For example, if the model has DataParallel submodules, or if the model converts tensors to numpy arrays and calls OpenCV functions, etc., you'll have to refactor it.
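
    For instance (a hypothetical before/after sketch, assuming a batched NCHW input), a numpy round-trip escapes the graph and must be replaced with equivalent torch ops:

    import torch

    def upsample2x(x: torch.Tensor) -> torch.Tensor:
        # before: torch.from_numpy(cv2.resize(x.numpy(), ...)) leaves the graph,
        # so the traced result would be baked in as a constant
        # after: the same computation expressed purely in torch ops stays traceable
        return torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")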

    Apart from this obvious issue, there are only two requirements for traceability.

  2. Input/output format

    Model's inputs/outputs have to be Union[Tensor, Tuple[Tensor]] to be traceable.

    This might appear worse than scripting, because scripting at least has good support for strongly-typed dicts. However, here the format constraint does not apply to submodules: submodules can use any input/output format: classes, kwargs, anything that Python supports.

    The format requirement only applies to the outer-most model, so it's very easy to address. If the model uses richer formats, just create a simple wrapper around it that converts to/from Tuple[Tensor]. Detectron2 even automates this for all its models with a universal wrapper like this:

    outputs = model(inputs)   # inputs/outputs are rich structure
    # torch.jit.trace(model, inputs) # FAIL! unsupported format
    adapter = TracingAdapter(model, inputs)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs) # Can now trace the model

    # Traced model can only produce flattened outputs (tuple of tensors):
    flattened_outputs = traced(*adapter.flattened_inputs)
    # Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
    new_outputs = adapter.outputs_schema(flattened_outputs)
  3. Symbolic shapes:

    Expressions like tensor.size(0), tensor.size()[1], tensor.shape[2] return int in eager mode, but Tensor in tracing mode. This is necessary to allow shape computations to be captured as symbolic operations in the graph. An example is given in the next section about generalization.

    Due to the different return types, a model may be untraceable if parts of it assume shapes are integers. This can often be fixed easily by handling both types. A helpful function is torch.jit.is_tracing, which checks whether the code is executed in tracing mode, as in the sketch below.
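
    A minimal sketch (hypothetical function) of using torch.jit.is_tracing to keep eager-only logic out of the trace:

    import torch

    def normalize_by_batch(x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)  # int in eager mode, Tensor in tracing mode
        if not torch.jit.is_tracing():
            # eager-only sanity check: comparing a traced Tensor here would
            # emit a TracerWarning and risk freezing a constant into the graph
            assert n > 0, "empty batch"
        return x / n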

That's all it takes for traceability - most importantly, any Python syntax is allowed in the model's implementation, because tracing does not care about syntax at all.

Generalization Problem

Just being "traceable" is not sufficient.The biggest problem with tracing, is that it may not generalize to other inputs.It happens in the following cases:

  1. Dynamic control flow:

    >>> def f(x):
    ...     return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    >>> m = torch.jit.trace(f, torch.tensor(3))
    >>> print(m.code)
    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)

    In this example, due to dynamic control flow, the trace only keeps one branch of the condition, and will not generalize to certain (negative) inputs.

  2. Capture variables as constants:

    >>> a, b = torch.rand(1), torch.rand(2)
    >>> def f1(x): return torch.arange(x.shape[0])
    >>> def f2(x): return torch.arange(len(x))
    >>> torch.jit.trace(f1, a)(b)
    tensor([0, 1])
    >>> torch.jit.trace(f2, a)(b)
    tensor([0])  # WRONG!
    >>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
    def f1(x: Tensor) -> Tensor:
      _0 = ops.prim.NumToTensor(torch.size(x, 0))
      _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _1
    def f2(x: Tensor) -> Tensor:
      _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _0

    Intermediate computation results may be captured as constants, using the value observed during tracing. This causes the trace to not generalize.

    In addition to len(), this issue can also appear in:

    • .item() which converts tensors to int/float.
    • Any other code that converts torch types to numpy/python primitives.
    • A few problematic operators, e.g. advanced indexing.
  3. Capture device:

    >>> def f(x):
    ...     return torch.arange(x.shape[0], device=x.device)
    >>> m = torch.jit.trace(f, torch.tensor([3]))
    >>> print(m.code)
    def f(x: Tensor) -> Tensor:
      _0 = ops.prim.NumToTensor(torch.size(x, 0))
      _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _1
    >>> m(torch.tensor([3]).cuda()).device
    device(type='cpu')  # WRONG!

    Similarly, operators that accept a device argument will remember the device used during tracing (this can be seen in m.code), so the trace may not generalize to inputs on a different device. Such generalization is almost never needed, because deployment usually has a fixed target device.
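
    If device generalization is ever needed, one possible fix is a sketch like the following, using the @script_if_tracing decorator discussed later, so that device=x.device stays symbolic instead of being frozen:

    import torch

    @torch.jit.script_if_tracing
    def arange_like(x: torch.Tensor) -> torch.Tensor:
        # compiled by scripting when the surrounding model is traced,
        # so the device is taken from the actual input at runtime
        return torch.arange(x.shape[0], device=x.device)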

Let Tracing Generalize

The above problems are annoying, but they can be successfully addressed by good practice and tools:

  • Pay attention to TracerWarning: In the first two examples above, torch.jit.trace actually emits warnings. The first example prints:

    a.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
    We can't record the data flow of Python values, so this value will be treated as a constant in the future.
    This means that the trace might not generalize to other inputs!
    if x.sum() > 0:

    Paying attention to these warnings (or, even better, parsing them) will expose most generalization problems of tracing.

    Note that the third example above does not print warnings because tracing was not designed to support such generalization at all.

  • Unittests for parity: Unittests should be done after export and before deployment, to verify that the exported model produces the same outputs as the original eager-mode model, i.e.

    assert allclose(torch.jit.trace(model, input1)(input2), model(input2))

    If generalization across shapes is needed (not always needed), input2 should have different shapes from input1.

    Detectron2 has many generalization tests, e.g. this and this. Once a gap is found, inspecting the code of the exported TS-format model can uncover the place where it fails to generalize. A minimal version of such a parity test is sketched below.
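
    A minimal sketch of a parity test (hypothetical helper, assuming a single-tensor interface):

    import torch

    def check_trace_parity(model, trace_input: torch.Tensor, test_input: torch.Tensor, atol: float = 1e-6):
        traced = torch.jit.trace(model, trace_input)
        with torch.no_grad():
            expected = model(test_input)
            actual = traced(test_input)
        # a generalization gap shows up as a numeric mismatch on the new input
        assert torch.allclose(expected, actual, atol=atol)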

  • Avoid unnecessary "special case" control flows: avoid code like

    if x.numel() > 0:
        output = self.layers(x)
    else:
        output = torch.zeros(...)

    that handles special cases such as empty inputs. Instead, improve self.layers or its underlying kernel so it supports empty inputs. This produces cleaner code and also improves tracing. This is why I'm involved in many PyTorch issues that improve support for empty inputs, such as #12013, #36530, #56998.

  • Use symbolic shapes: as mentioned above, tensor.size() returns Tensor during tracing, so that computations using shapes are captured in the graph. Users should avoid accidentally turning tensor shapes into constants:

    • Use tensor.size(0) instead of len(tensor), because the latter is an int. For custom classes, implement a .size method or use .__len__() instead of len(), e.g. here.
    • Do not convert sizes with int() or torch.as_tensor, because they will capture constants. This helper function is useful to convert sizes into a tensor in a way that works in both tracing and eager mode; a simplified sketch of it follows this list.
  • Mix tracing and scripting: they can be mixed together, so you can use scripting on the small portion of code where tracing does not work correctly. This can fix almost all problems of tracing. More on this below.
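
A simplified sketch of the size-conversion helper mentioned above (detectron2's real implementation handles more cases; treat this version as an assumption, not its exact code):

import torch
from typing import List

def shapes_to_tensor(x: List[int], device: torch.device = None) -> torch.Tensor:
    if torch.jit.is_tracing():
        # in tracing mode the list entries are 0-dim Tensors; stacking them
        # keeps the shape computation symbolic instead of freezing constants
        return torch.stack([torch.as_tensor(t, device=device) for t in x])
    # in eager mode the entries are plain ints
    return torch.as_tensor(x, device=device)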

Mix Tracing and Scripting

Tracing and scripting both have their own problems, and the best solution is usually to mix them together. This gives us the best of both worlds.

To minimize the negative impact on code quality, we should use tracing for the majority of the logic, and use scripting only when necessary.

  1. Use @script_if_tracing: Inside torch.jit.trace, the @script_if_tracing decorator can compile functions by scripting. Typically, this only requires a small refactor of the forward logic to separate the parts that need to be compiled (the parts with control flow):

    def forward(self, ...):
        # ... some forward logic

        @torch.jit.script_if_tracing
        def _inner_impl(x, y, z, flag: bool):
            # use control flow, etc.
            return ...

        output = _inner_impl(x, y, z, flag)
        # ... other forward logic

    By scripting only the parts that need it, the code-quality damage is strictly smaller than making the entire model scriptable, and it does not affect the module's forward interface at all.

    In fact, for most vision models, dynamic control flow is needed only in a few submodules where it's easy to be scriptable. To show how rarely it is needed: the entire detectron2 has only two functions decorated with @script_if_tracing due to control flow, paste_masks and heatmaps_to_keypoints, both for post-processing only. A few other functions are decorated with @script_if_tracing to generalize across devices (a very rare requirement).
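
    A concrete sketch of this pattern (a hypothetical function; any shape- or data-dependent branch works the same way):

    import torch

    @torch.jit.script_if_tracing
    def pad_to_even(x: torch.Tensor) -> torch.Tensor:
        # scripting preserves both branches; plain tracing would freeze
        # whichever branch happened to run on the example input
        if x.shape[0] % 2 == 1:
            x = torch.cat([x, x[-1:]], dim=0)
        return x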

  2. Use scripted / traced submodules:

    model.submodule = torch.jit.script(model.submodule)
    torch.jit.trace(model, inputs)

    In the above example, imagine that submodule cannot be traced correctly, so we script it before tracing. However, I do not recommend this. If possible, I suggest using @script_if_tracing inside submodule.forward instead, so that scripting is limited to the internals of the submodule, without affecting the module's interface.

    And similarly,

    model.submodule = torch.jit.trace(model.submodule, submodule_inputs)
    torch.jit.script(model)

    this uses a traced submodule during scripting. This looks nice, but is not so useful in practice: it will affect the interface of submodule, requiring it to only accept/return Tuple[Tensor] -- this is a big constraint that might hurt code quality even more than scripting.

    A rare scenario where "tracing a submodule" is useful is this:

    class A(nn.Module):
        def forward(self, x):
            # Dispatch to different submodules based on a dynamic, data-dependent condition:
            return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)

    @script_if_tracing cannot compile such control flow because it only supports pure functions. If submodule{1,2} are complex and cannot be scripted, using traced submodules in a scripted parent A is the best option.

  3. Merge multiple traces:

    Scripted models support two more features that traced models don't:

    • Control flow conditioned on attributes: a scripted module can have mutable attributes (e.g. a boolean flag) that affect control flow. Traced modules do not have control flow.
    • Multiple methods: a traced module only supports forward(), but a scripted module can have multiple methods.

    Actually, both features above are doing the same thing: they allow an exported model to be used in different ways, i.e. to execute different sequences of operators as requested by the caller.

    Below is an example scenario where such a feature is useful: if Detector is scripted, the caller can mutate its do_keypoint attribute to control its behavior, or call the predict_keypoint method directly if needed.

    class Detector(nn.Module):
        do_keypoint: bool

        def forward(self, img):
            box = self.predict_boxes(img)
            if self.do_keypoint:
                kpts = self.predict_keypoint(img, box)

        @torch.jit.export
        def predict_boxes(self, img): pass

        @torch.jit.export
        def predict_keypoint(self, img, box): pass

    This requirement is not seen very often. But if needed, how to achieve this in tracing? I have a solution that's not very clean:

    Tracing can only capture one sequence of operators, so the natural way is to trace the model twice:

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)

    We can then alias their weights (to not duplicate the storage), and merge the two traces into one module to script:

    det2.submodule.weight = det1.submodule.weight

    class Wrapper(nn.ModuleList):
        def forward(self, img, do_keypoint: bool):
            if do_keypoint:
                return self[0](img)
            else:
                return self[1](img)

    exported = torch.jit.script(Wrapper([det1, det2]))

Performance

If a model is both traceable and scriptable, tracing always generates the same or a simpler graph (and is therefore likely faster).

Why? Because scripting tries to faithfully represent your Python code, even the parts of it that are unnecessary. For example, it is not always smart enough to realize that some loops or data structures in the Python code are actually static and can be removed:

class A(nn.Module):
    def forward(self, x1, x2, x3):
        z = [0, 1, 2]
        xs = [x1, x2, x3]
        for k in z:
            x1 += xs[k]
        return x1

model = A()

print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#     z = [0, 1, 2]
#     xs = [x1, x2, x3]
#     x10 = x1
#     for _0 in range(torch.len(z)):
#         k = z[_0]
#         x10 = torch.add_(x10, xs[k])
#     return x10

print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#     x10 = torch.add_(x1, x1)
#     x11 = torch.add_(x10, x2)
#     return torch.add_(x11, x3)

This example is very simple, so it actually has workarounds for scripting (use a tuple instead of a list), or the loop might get optimized away in a later optimization pass. But the point is: the graph compiler is not always smart enough. For complicated models, scripting might generate a graph with unnecessary complexity that's hard to optimize.

Concluding Thoughts

Tracing has clear limitations: I spent most of this article talking about the limitations of tracing and how to fix them. I actually think this is the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works. On the contrary, scripting of advanced Python syntax is more like a black box: no one knows whether it works before trying.

Tracing has a small blast radius: Both tracing and scripting affect how code can be written, but tracing has a much smaller blast radius, and causes much less damage:

  • It limits the input/output format, but on the outer-most module only. (And the issue can be automatically solved as discussed above.)
  • It needs some code changes to generalize (e.g. to mix scripting into tracing), but these changes only go into the internal implementation of the affected modules, not their interfaces.

On the other hand, scripting has an impact on:

  • The interface of every module/submodule involved.
    • IMO, this is the biggest damage: advanced syntax features are needed in interfaces, and I'm not willing to compromise on interface design.
    • This may end up affecting training as well, because the interface is often shared between training and inference.
  • Pretty much every line of code in the inference forward path.

Control flow vs. other Python syntax: PyTorch is loved by its users because they can "just write Python", and most importantly write Python control flow. But other Python syntax is important as well. If being able to write Python control flow (scripting) means losing other great syntax, I'd rather give up the ability to write Python control flow.

In fact, if PyTorch were less obsessed with Python control flow, and instead offered me symbolic control flow such as a torch.cond like this (similar to the API of tf.cond):

def f(x):
    return torch.cond(x.sum() > 0, lambda: torch.sqrt(x), lambda: torch.square(x))

then f could be traced correctly, and I would be happy to use it, no longer having to worry about scripting. TensorFlow AutoGraph is a great example that automates this idea.


Source: https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/
