Source code for torch.ao.quantization.fake_quantize
"""This module implements modules which are used to perform fake quantizationduring QAT."""importtorchfromtorch.nnimportModulefromtorch.ao.quantization.observerimport(MovingAverageMinMaxObserver,HistogramObserver,MovingAveragePerChannelMinMaxObserver,FixedQParamsObserver,default_fixed_qparams_range_0to1_observer,default_fixed_qparams_range_neg1to1_observer,_with_args,)importrefromabcimportABC,abstractmethodfromtypingimportAny,Tupledef_is_per_channel(qscheme:'torch.qscheme')->bool:returnqschemein[torch.per_channel_symmetric,torch.per_channel_affine,torch.per_channel_affine_float_qparams]def_is_per_tensor(qscheme:'torch.qscheme')->bool:returnqschemein[torch.per_tensor_symmetric,torch.per_tensor_affine]def_is_symmetric_quant(qscheme:'torch.qscheme')->bool:returnqschemein[torch.per_tensor_symmetric,torch.per_channel_symmetric]def_is_float_qparams(qscheme:'torch.qscheme')->bool:returnqschemein[torch.per_channel_affine_float_qparams,]
class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    with_args = classmethod(_with_args)
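# NOTE: the example below is not part of the original module. It sketches how
# the enable/disable toggles are typically sequenced during QAT; `fq` is any
# concrete subclass instance (e.g. the FakeQuantize defined below), and the
# function name is ours, for illustration only.
def _example_toggle_phases(fq: FakeQuantizeBase) -> None:
    # Phase 1: calibrate -- collect statistics without perturbing activations.
    fq.enable_observer()
    fq.disable_fake_quant()
    # Phase 2: train with fake quantization applied, statistics still updating.
    fq.enable_fake_quant()
    # Phase 3: freeze the quantization parameters near the end of training.
    fq.disable_observer()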
class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (
            clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake quantization;
      allowable values are torch.qint8 and torch.quint8.

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
          and zero-point.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): User-provided module that collects statistics on the input
          tensor and provides a method to calculate scale and zero-point.

    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, **observer_kwargs):
        super().__init__()
        # Populate quant_min/quant_max to observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert quant_min <= quant_max, \
                'quant_min must be less than or equal to quant_max'
            dtype = observer_kwargs.get("dtype", torch.quint8)
            if hasattr(observer, "p"):
                # In case observer is _PartialWrapper, dtype can be stored in
                # observer.p.keywords["dtype"]
                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get("dtype", dtype)
            assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
            assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
        self.activation_post_process = observer(**observer_kwargs)
        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
        # Users should use self.activation_post_process.quant_min
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize,' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.activation_post_process.quant_min,
                   self.activation_post_process.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded tensor does not
        # match the original size, i.e., these buffers start out with numel 0 and become
        # numel 1 once they have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
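# NOTE: the example below is not part of the original module. It is a sketch
# checking FakeQuantize against the formula in its docstring; the function
# name is ours, for illustration only.
def _example_fake_quantize_roundtrip() -> None:
    fq = FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                      dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    x = torch.randn(2, 3)
    y = fq(x)  # first forward: observe x, update scale/zero_point, fake-quantize
    scale, zero_point = fq.calculate_qparams()
    manual = (torch.clamp(torch.round(x / scale + zero_point), 0, 255) - zero_point) * scale
    # Agreement should be exact apart from floating-point tie-breaking at .5
    # boundaries, which is vanishingly rare for random input.
    assert torch.allclose(y, manual)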
class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize with fixed quantization parameters in training time.

    Only per tensor quantization is supported.
    """

    def __init__(self, observer):
        super().__init__(observer=observer)
        assert type(self.activation_post_process) == FixedQParamsObserver, \
            "%s's observer must be a %s" % (self.__class__.__name__, FixedQParamsObserver.__name__)
        self._observer_ctr = observer
        self.scale = self.activation_post_process.scale
        self.zero_point = self.activation_post_process.zero_point
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.activation_post_process.quant_min,
                   self.activation_post_process.quant_max, self.qscheme)
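# NOTE: the example below is not part of the original module. Fixed qparams
# suit ops whose output range is known ahead of time (e.g. sigmoid maps into
# [0, 1]); the qparams never move, whatever data flows through. The function
# name is ours, for illustration only.
def _example_fixed_qparams() -> None:
    fq = FixedQParamsFakeQuantize(observer=default_fixed_qparams_range_0to1_observer)
    scale_before, _ = fq.calculate_qparams()
    fq(torch.rand(8))  # observing data does not change the qparams
    scale_after, _ = fq.calculate_qparams()
    assert torch.equal(scale_before, scale_after)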
class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point, and fake_quantize the tensor.

    This module uses a calculation similar to MovingAverageMinMaxObserver for the inputs
    to compute the min/max values, which in turn determine the scale/zero_point.
    The qscheme input to the observer is used to differentiate between the symmetric and
    affine quantization schemes.

    The output of this module is given by

        x_out = (clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale

    Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
    base class.
    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        assert isinstance(self.activation_post_process,
                          (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
                self.fake_quant_enabled,
                self.observer_enabled,
                self.scale,
                self.zero_point,
                self.dtype,
                self.activation_post_process.quant_min,
                self.activation_post_process.quant_max,
                self.qscheme,
                self.activation_post_process.reduce_range,
            )
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )
default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
"""
Default fake_quant for activations.
"""

default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
                                                   reduce_range=False)
"""
Default fake_quant for weights.
Observer is memoryless since averaging_constant is 1.
"""

default_dynamic_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                                    dtype=torch.quint8, averaging_constant=1)
"""
Default dynamic fake_quant for activations.
"""

default_fixed_qparams_range_neg1to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
"""
Default fake_quant for per-channel weights.
Observer is memoryless since averaging_constant is 1.
"""

default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                      qscheme=torch.per_channel_affine_float_qparams,
                                                      dtype=torch.quint8,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      ch_axis=0,
                                                      averaging_constant=1)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""

default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                           qscheme=torch.per_channel_affine_float_qparams,
                                                           ch_axis=0,
                                                           dtype=torch.quint4x2,
                                                           averaging_constant=1)

default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
"""
Fake_quant for activations using a histogram.
"""

default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                       quant_min=0,
                                                                       quant_max=255,
                                                                       dtype=torch.quint8)
"""
Fused version of `default_fake_quant`, with improved performance.
"""

default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                      quant_min=-128,
                                                                      quant_max=127,
                                                                      dtype=torch.qint8,
                                                                      qscheme=torch.per_tensor_symmetric)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""

default_fused_per_channel_wt_fake_quant = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                            quant_min=-128,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_channel_symmetric)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""

fused_wt_fake_quant_range_neg_127_to_127 = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                            quant_min=-127,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_tensor_symmetric,
                                            eps=2 ** -12)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                            quant_min=-127,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_channel_symmetric,
                                            eps=2 ** -12)
"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127],
excluding -128.
"""


def _is_fake_quant_script_module(mod):
    """Returns true if the given mod is an instance of a FakeQuantize script module."""
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split('.', 1)[1]
        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
        return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
            name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
    return False


def disable_fake_quant(mod):
    """Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_fake_quant)
    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_fake_quant)
    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()


def disable_observer(mod):
    """Disable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_observer)
    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_observer()


def enable_observer(mod):
    """Enable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_observer)
    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_observer()
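# NOTE: the example below is not part of the original module. The four
# module-level helpers above are written for nn.Module.apply, so one call
# toggles every fake-quant submodule of a prepared QAT model at once; the
# function name is ours, for illustration only.
def _example_qat_phase_control(model: torch.nn.Module) -> None:
    # Calibration: observers on, fake quantization off.
    model.apply(enable_observer)
    model.apply(disable_fake_quant)
    # QAT fine-tuning: fake quantization on.
    model.apply(enable_fake_quant)
    # End of training: freeze quantization parameters.
    model.apply(disable_observer)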