Latest commit

087d085 · Oct 10, 2024

190 lines (130 loc) · 7.32 KB

Quantization

Quantization is a technique that can reduce the model size and accelerate its execution with little to no degradation in accuracy. CTranslate2 supports the most common types:

  • 8-bit integers (INT8)
  • 16-bit integers (INT16)
  • 16-bit floating points (FP16)
  • 16-bit brain floating points (BF16)
  • 4-bit activation-aware weight quantization (AWQ)
See the benchmark results in the main [README](https://github.com/OpenNMT/CTranslate2#benchmarks) to compare the performance and memory usage with and without quantization.

Quantize on model conversion

Enabling quantization when converting the model reduces its size on disk. The converters expose the option `quantization`, which accepts the following values:

  • int8
  • int8_float32
  • int8_float16
  • int8_bfloat16
  • int16
  • float16
  • bfloat16
  • float32

For example,

ct2-opennmt-py-converter --model_path model.pt --quantization int8 --output_dir ct2_model

When the option --quantization is not set, the converted model will be saved with the same type as the original model (typically one of float32, float16, or bfloat16).

Whatever quantization type is selected here, the runtime ensures the model can be loaded and executed efficiently. This means the model weights may be converted to another type when the model is loaded; see {ref}`quantization:implicit type conversion on load`.

For reference, the table below compares the model size on disk for a base Transformer model without shared embeddings and a vocabulary of size 32k:

| Quantization | Model size |
| --- | --- |
| float32 | 364MB |
| int16 | 187MB |
| float16 | 182MB |
| bfloat16 | 182MB |
| int8_float32 | 100MB |
| int8_float16 | 95MB |
| int8_bfloat16 | 95MB |

Quantize on model loading

Quantization can also be enabled or changed when loading the model. The translator exposes the option `compute_type`, which accepts the following values:

  • default: keep the same quantization that was used during model conversion (see {ref}`quantization:implicit type conversion on load` for exceptions)
  • auto: use the fastest computation type that is supported on this system and device
  • int8
  • int8_float32
  • int8_float16
  • int8_bfloat16
  • int16
  • float16
  • float32
  • bfloat16

For example,

import ctranslate2

translator = ctranslate2.Translator("ct2_model", compute_type="int8")

Conversions between all types are supported. For example, you can convert a model with `quantization="int8"` and then execute it in full precision with `compute_type="float32"`.

Implicit type conversion on load

By default, the runtime tries to use the type saved in the converted model as the computation type. However, if the current platform or backend does not support optimized execution for this computation type (e.g. int16 is not optimized on GPU), the library converts the model weights to another optimized type. The tables below document the fallback types in prebuilt binaries:

On CPU:

| Architecture | int8_float32 | int8_float16 | int8_bfloat16 | int16 | float16 | bfloat16 |
| --- | --- | --- | --- | --- | --- | --- |
| x86-64 (Intel) | int8_float32 | int8_float32 | int8_float32 | int16 | float32 | float32 |
| x86-64 (other) | int8_float32 | int8_float32 | int8_float32 | int8_float32 | float32 | float32 |
| AArch64/ARM64 (Apple) | int8_float32 | int8_float32 | int8_float32 | int8_float32 | float32 | float32 |
| AArch64/ARM64 (other) | int8_float32 | int8_float32 | int8_float32 | int8_float32 | float32 | float32 |
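The CPU table can be read as a small lookup. The sketch below transcribes it into a hypothetical helper, `cpu_fallback`, which is not part of the CTranslate2 API. The architecture strings are simply the table's row labels. Treating float32 as unchanged is an assumption, since the table does not list it.

```python
def cpu_fallback(architecture: str, saved_type: str) -> str:
    """Hypothetical transcription of the CPU fallback table.

    Returns the compute type used on CPU for a model saved with saved_type.
    """
    if saved_type == "float32":
        return "float32"  # not listed in the table; assumed to run unchanged
    if saved_type.startswith("int8"):
        return "int8_float32"  # every int8 variant falls back to int8_float32
    if saved_type == "int16":
        # int16 is only kept on Intel x86-64 CPUs
        return "int16" if architecture == "x86-64 (Intel)" else "int8_float32"
    if saved_type in ("float16", "bfloat16"):
        return "float32"  # 16-bit float types run in float32 on CPU
    raise ValueError(f"type not covered by the table: {saved_type}")
```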

On GPU:

| Compute Capability | int8_float32 | int8_float16 | int8_bfloat16 | int16 | float16 | bfloat16 |
| --- | --- | --- | --- | --- | --- | --- |
| >= 8.0 | int8_float32 | int8_float16 | int8_bfloat16 | float16 | float16 | bfloat16 |
| >= 7.0, < 8.0 | int8_float32 | int8_float16 | int8_float32 | float16 | float16 | float32 |
| 6.2 | float32 | float32 | float32 | float32 | float32 | float32 |
| 6.1 | int8_float32 | int8_float32 | int8_float32 | float32 | float32 | float32 |
| <= 6.0 | float32 | float32 | float32 | float32 | float32 | float32 |

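Likewise, a hypothetical `gpu_fallback` helper can encode the GPU table. It is not a CTranslate2 function. The Compute Capability is passed as a `(major, minor)` tuple. The sketch transcribes the table for illustration and assumes that float32 models stay in float32.

```python
from typing import Dict, Tuple

# Rows of the GPU table: saved type (column) -> compute type actually used.
_GPU_ROWS: Dict[str, Dict[str, str]] = {
    "cc>=8.0": {
        "int8_float32": "int8_float32", "int8_float16": "int8_float16",
        "int8_bfloat16": "int8_bfloat16", "int16": "float16",
        "float16": "float16", "bfloat16": "bfloat16",
    },
    "7.0<=cc<8.0": {
        "int8_float32": "int8_float32", "int8_float16": "int8_float16",
        "int8_bfloat16": "int8_float32", "int16": "float16",
        "float16": "float16", "bfloat16": "float32",
    },
    "cc==6.1": {
        "int8_float32": "int8_float32", "int8_float16": "int8_float32",
        "int8_bfloat16": "int8_float32", "int16": "float32",
        "float16": "float32", "bfloat16": "float32",
    },
}


def gpu_fallback(capability: Tuple[int, int], saved_type: str) -> str:
    """Hypothetical lookup of the compute type used on a GPU."""
    if saved_type == "float32":
        return "float32"  # not listed in the table; assumed unchanged
    if capability >= (8, 0):
        row = _GPU_ROWS["cc>=8.0"]
    elif capability >= (7, 0):
        row = _GPU_ROWS["7.0<=cc<8.0"]
    elif capability == (6, 1):
        row = _GPU_ROWS["cc==6.1"]
    else:
        # 6.2 and <= 6.0: every type falls back to float32
        return "float32"
    return row[saved_type]
```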
You can get more information about the detected capabilities of your system by enabling the info logs (set the environment variable `CT2_VERBOSE=1` or call `ctranslate2.set_log_level(logging.INFO)`).

The supported compute types can also be queried at runtime with the Python function [`ctranslate2.get_supported_compute_types`](python/ctranslate2.get_supported_compute_types.rst).
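An application may want to choose among several compute types itself. One option is to intersect its own preference list with the set returned by `ctranslate2.get_supported_compute_types`. The `pick_compute_type` helper and the preference order in the usage note are hypothetical. They do not describe how `compute_type="auto"` is implemented.

```python
from typing import Iterable, Sequence


def pick_compute_type(supported: Iterable[str], preferences: Sequence[str]) -> str:
    """Return the first preferred compute type that the device supports."""
    available = set(supported)
    for compute_type in preferences:
        if compute_type in available:
            return compute_type
    raise ValueError("none of the preferred compute types is supported")
```

For example, `pick_compute_type(ctranslate2.get_supported_compute_types("cuda"), ["int8_float16", "float16", "float32"])` would return the first of those types that is available on the current GPU.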

Supported types

8-bit integers (int8)

Supported on:

  • NVIDIA GPU with Compute Capability >= 7.0 or Compute Capability 6.1
  • x86-64 CPU with the Intel MKL or oneDNN backends
  • AArch64/ARM64 CPU with the Ruy backend

The implementation applies the equation from Wu et al. 2016 to quantize the weights of the embedding and linear layers:

scale[i] = 127 / max(abs(W[i,:]))

WQ[i,j] = round(scale[i] * W[i,j])

This is symmetric quantization: the scale is derived from the absolute maximum of each row rather than from separate minimum and maximum values.
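The two formulas above can be written out directly. The pure-Python sketch below quantizes each row with its own scale. It also shows the approximate inverse, `W[i,j] ≈ WQ[i,j] / scale[i]`. The function names and the guard for an all-zero row are assumptions added for illustration. Python's `round` rounds halves to even, which may not match the library's rounding mode.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_int8_rows(weights: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row int8 quantization: scale[i] = 127 / max(abs(W[i,:]))."""
    quantized, scales = [], []
    for row in weights:
        amax = max(abs(w) for w in row)
        # Guard against an all-zero row (an assumption; not described in the text).
        scale = 127.0 / amax if amax > 0 else 1.0
        # WQ[i,j] = round(scale[i] * W[i,j]); values lie in [-127, 127].
        quantized.append([round(scale * w) for w in row])
        scales.append(scale)
    return quantized, scales


def dequantize_int8_rows(quantized: List[List[int]], scales: List[float]) -> Matrix:
    """Approximate inverse: W[i,j] ~ WQ[i,j] / scale[i]."""
    return [[q / s for q in row] for row, s in zip(quantized, scales)]
```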

Non-quantized layers run in the floating-point precision of the original model. For example, if the model is saved in float16, the actual quantization type will be int8_float16. You can change this behavior by selecting a more specific quantization type:

  • int8_float32
  • int8_float16
  • int8_bfloat16

16-bit integers (int16)

Supported on:

  • Intel CPU with the Intel MKL backend

The implementation follows the work by Devlin 2017. By default we use one quantization scale per layer. The scale is defined as:

scale = 2^10 / max(abs(W))

As suggested by the author, inputs are quantized to 10 bits so that each product fits in 20 bits. This leaves 12 bits of the 32-bit accumulator for summing the products.
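The int16 scheme differs from int8 mainly in that it uses a single scale for the whole layer. The minimal sketch below shows this scheme together with the bit-budget arithmetic from the paragraph above. The function name and the guard for an all-zero layer are hypothetical.

```python
from typing import List, Tuple


def quantize_int16_layer(weights: List[List[float]]) -> Tuple[List[List[int]], float]:
    """Per-layer int16 quantization with scale = 2^10 / max(abs(W))."""
    amax = max(abs(w) for row in weights for w in row)
    scale = 2**10 / amax if amax > 0 else 1.0  # zero guard is an assumption
    quantized = [[round(scale * w) for w in row] for row in weights]
    return quantized, scale


# Bit budget from the text: two 10-bit operands give a 20-bit product,
# leaving 32 - 20 = 12 bits of a 32-bit accumulator for the sum.
PRODUCT_BITS = 10 + 10
ACCUMULATION_BITS = 32 - PRODUCT_BITS
```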

As with int8 quantization, only the weights of the embedding and linear layers are quantized to 16-bit integers. The other layers run in FP32.

16-bit floating points (float16)

Supported on:

  • NVIDIA GPU with Compute Capability >= 7.0

In this mode, all model weights are stored in half precision and all layers are run in half precision.
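Python's standard `struct` module can round a value through IEEE 754 binary16 (the `"e"` format). The example below uses it to show what half precision costs in accuracy and range. It illustrates the number format only, not how CTranslate2 converts weights internally.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    # The "e" format packs a binary16 value and raises OverflowError
    # for values beyond the float16 range (largest finite value: 65504).
    return struct.unpack("<e", struct.pack("<e", x))[0]
```

The rounding error of about 2.4e-5 on 0.1 and the overflow above 65504 are properties of the float16 format itself.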

16-bit brain floating points (bfloat16)

Supported on:

  • NVIDIA GPU with Compute Capability >= 8.0

In this mode, all model weights are stored in BF16 and all layers are run with this type.
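bfloat16 keeps the 8-bit exponent of float32 but only 7 explicit mantissa bits, so it covers the same range with less precision. The sketch below is a hypothetical helper, not CTranslate2 code. It rounds a float32 bit pattern to its upper 16 bits using round-to-nearest-even and ignores NaN handling.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to bfloat16 (round-to-nearest-even) and widen it back."""
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Keep the upper 16 bits, rounding to nearest even on the dropped half.
    # NaN inputs are not handled in this sketch.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]
```

70000.0 lies beyond the float16 maximum of 65504 but is representable in bfloat16, where it rounds to 70144.0 because only 7 mantissa bits remain.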

4-bit AWQ

Supported on:

  • NVIDIA GPU with Compute Capability >= 7.5

CTranslate2 internally handles the compute type for AWQ quantization. In this mode, all model weights are stored in half precision and all layers are run in half precision. Other parameters like scale and zero are stored in int32.
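4-bit values are typically packed several to a 32-bit integer. The sketch below shows that packing and the usual group-wise asymmetric dequantization, `w = (q - zero) * scale`. The sequential nibble order and the helper names are assumptions chosen for illustration. They are not necessarily the layout used by CTranslate2 or by the AWQ GEMM/GEMV kernels.

```python
from typing import List


def pack_int4(values: List[int]) -> int:
    """Pack eight unsigned 4-bit values into one 32-bit integer (sequential order)."""
    if len(values) != 8 or not all(0 <= v < 16 for v in values):
        raise ValueError("expected eight values in [0, 15]")
    packed = 0
    for i, v in enumerate(values):
        packed |= v << (4 * i)  # value i occupies bits 4*i .. 4*i+3
    return packed


def unpack_int4(packed: int) -> List[int]:
    """Inverse of pack_int4."""
    return [(packed >> (4 * i)) & 0xF for i in range(8)]


def dequantize_group(q: List[int], zero: int, scale: float) -> List[float]:
    """Asymmetric dequantization of one group: w = (q - zero) * scale."""
    return [(v - zero) * scale for v in q]
```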

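Steps to use AWQ quantization:

  • Convert the AWQ-quantized model to the CTranslate2 format:
ct2-transformers-converter --model TheBloke/Llama-2-7B-AWQ --copy_files tokenizer.model --output_dir ct2_model
  • Run inference as usual with CTranslate2:
import ctranslate2
model = ctranslate2.Generator("ct2_model", device="cuda")
outputs = model.generate_batch([tokens])  # tokens: list of token strings produced by the model's tokenizer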

Currently, CTranslate2 only supports the GEMM and GEMV kernels for AWQ quantization.