Source code for torch.distributed.rpc.backend_registry
__all__=["init_backend","backend_registered","construct_rpc_backend_options","register_backend","BackendType","BackendValue"]importcollectionsimportenumfromtypingimportcast,Dict,List,Set,Tupleimporttorchimporttorch.distributedasdistfrom._utilsimport_group_membership_management,_update_group_membershipfrom.importapifrom.importconstantsasrpc_constantsBackendValue=collections.namedtuple("BackendValue",["construct_rpc_backend_options_handler","init_backend_handler"])def_backend_type_repr(self):return"BackendType."+self.name_backend_type_doc=""" An enum class of available backends. PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. Additional ones can be registered using the :func:`~torch.distributed.rpc.backend_registry.register_backend` function."""# Create an enum type, `BackendType`, with empty members.# Can't handle Function Enum API (mypy bug #9079)BackendType=enum.Enum(value="BackendType",names=dict())# type: ignore[misc]# Unable to assign a function a method (mypy bug #2427)BackendType.__repr__=_backend_type_repr# type: ignore[assignment]ifBackendType.__doc__:BackendType.__doc__=_backend_type_docdefbackend_registered(backend_name):""" Checks if backend_name is registered as an RPC backend. Args: backend_name (str): string to identify the RPC backend. Returns: True if the backend has been registered with ``register_backend``, else False. """returnbackend_nameinBackendType.__members__.keys()defregister_backend(backend_name,construct_rpc_backend_options_handler,init_backend_handler):"""Registers a new RPC backend. Args: backend_name (str): backend string to identify the handler. construct_rpc_backend_options_handler (function): Handler that is invoked when rpc_backend.construct_rpc_backend_options(**dict) is called. init_backend_handler (function): Handler that is invoked when the `_init_rpc_backend()` function is called with a backend. This returns the agent. 
"""globalBackendTypeifbackend_registered(backend_name):raiseRuntimeError("RPC backend {}: already registered".format(backend_name))# Create a new enum type, `BackendType`, with extended members.existing_enum_dict={member.name:member.valueformemberinBackendType}extended_enum_dict=dict({backend_name:BackendValue(construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,init_backend_handler=init_backend_handler,)},**existing_enum_dict)# Can't handle Function Enum API (mypy bug #9079)BackendType=enum.Enum(value="BackendType",names=extended_enum_dict)# type: ignore[misc]# Unable to assign a function a method (mypy bug #2427)BackendType.__repr__=_backend_type_repr# type: ignore[assignment]ifBackendType.__doc__:BackendType.__doc__=_backend_type_docreturnBackendType[backend_name]defconstruct_rpc_backend_options(backend,rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,init_method=rpc_constants.DEFAULT_INIT_METHOD,**kwargs):returnbackend.value.construct_rpc_backend_options_handler(rpc_timeout,init_method,**kwargs)definit_backend(backend,*args,**kwargs):returnbackend.value.init_backend_handler(*args,**kwargs)def_init_process_group(store,rank,world_size):# Initialize ProcessGroup.process_group_timeout=rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT# We're using a bunch of private APIs here since `new_group` requires the# default group to be initialized.group=dist.ProcessGroupGloo(store,rank,world_size,process_group_timeout)assertgroupisnotNone,"Failed to initialize default ProcessGroup."if(rank!=-1)and(rank!=group.rank()):raiseRuntimeError("rank argument {} doesn't match pg rank {}".format(rank,group.rank()))if(world_size!=-1)and(world_size!=group.size()):raiseRuntimeError("world_size argument {} doesn't match pg size {}".format(world_size,group.size()))returngroupdef_tensorpipe_construct_rpc_backend_options_handler(rpc_timeout,init_method,num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,_transports=None,_channels=None,**kwargs):from.importTensorPipeRpcBackendOptionsreturnTensorPipeRpcBackendOptions(rpc_timeout=rpc_timeout,init_method=init_method,num_worker_threads=num_worker_threads,_transports=_transports,_channels=_channels,)def_tensorpipe_validate_devices(devices,device_count):returnall(d.type=="cpu"or(d.type=="cuda"and0<=d.index<device_count)fordindevices)# detect if any worker has invalid device_map configurations, and return# reverse device mapsdef_tensorpipe_exchange_and_check_all_device_maps(my_name,my_device_count,my_device_maps,my_devices,group):gathered:List[Tuple[str,int,Dict[str,Dict[torch.device,torch.device]],List[torch.device]]]=[("",0,{},[])for_inrange(group.size())]dist.all_gather_object(gathered,(my_name,my_device_count,my_device_maps,my_devices),group)all_names=[nameforname,_,_,_ingathered]all_device_counts={name:countforname,count,_,_ingathered}all_device_maps={name:map_forname,_,map_,_ingathered}all_devices={name:devicesforname,_,_,devicesingathered}_validate_device_maps(all_names,all_device_counts,all_device_maps,all_devices)# passed all checked, construct reverse mapping and get list of devices handled by this agentreverse_device_maps=_create_reverse_mapping(my_name,all_names,all_device_maps)my_devices=_create_device_list(my_devices,my_device_maps,reverse_device_maps)returnreverse_device_maps,my_devicesdef_validate_device_maps(all_names,all_device_counts,all_device_maps,all_devices,is_static_group=True):fornodeinall_names:devices=all_devices[node]iflen(set(devices))!=len(devices):raiseValueError(f"Node {node} has duplicated devices\n"f"devices = 
{devices}")ifnot_tensorpipe_validate_devices(devices,all_device_counts[node]):raiseValueError(f"Node {node} has devices with invalid indices\n"f"devices = {devices}\n"f"device count = {all_device_counts[node]}")forsource_nodeinall_names:# For dynamic group (non-static) do not check the target node name since it may not have joined yetifis_static_groupandnotset(all_device_maps[source_node].keys()).issubset(all_names):raiseValueError(f"Node {source_node} has invalid target node names in its device maps\n"f"device maps = {all_device_maps[source_node].keys()}\n"f"node names = {all_names}")fortarget_node,map_inall_device_maps[source_node].items():iflen(set(map_.values()))!=len(map_):raiseValueError(f"Node {source_node} has duplicated target devices "f"in its device map for {target_node}\n"f"device map = {map_}")ifall_devices[source_node]:ifnotset(map_.keys()).issubset(all_devices[source_node]):raiseValueError(f"Node {source_node} has unexpected source devices "f"in its device map for {target_node}\n"f"device map = {map_}\n"f"devices = {all_devices[source_node]}")elifnot_tensorpipe_validate_devices(map_.keys(),all_device_counts[source_node]):raiseValueError(f"Node {source_node} has source devices with invalid indices "f"in its device map for {target_node}\n"f"device map = {map_}\n"f"device count = {all_device_counts[source_node]}")ifall_devices.get(target_node,[]):ifnotset(map_.values()).issubset(all_devices[target_node]):raiseValueError(f"Node {source_node} has unexpected target devices "f"in its device map for {target_node}\n"f"device map = {map_}\n"f"devices = {all_devices[target_node]}")eliftarget_nodeinall_device_countsandnot_tensorpipe_validate_devices(map_.values(),all_device_counts[target_node]):raiseValueError(f"Node {source_node} has target devices with invalid indices "f"in its device map for {target_node}\n"f"device map = {map_}\n"f"device count = {all_device_counts[target_node]}")def_create_device_list(my_devices,my_device_maps,reverse_device_maps):ifnotmy_devices:devices_set:Set[torch.device]=set()for_,map_inmy_device_maps.items():devices_set.update(map_.keys())for_,map_inreverse_device_maps.items():devices_set.update(map_.keys())devices_set.discard(torch.device("cpu"))my_devices=list(devices_set)my_devices=sorted(my_devices,key=lambdad:d.index)returnmy_devicesdef_create_reverse_mapping(my_name,all_names,all_device_maps):reverse_device_maps:Dict[str,Dict[torch.device,torch.device]]={}fornodeinall_names:ifmy_nameinall_device_maps[node]:reverse_device_maps[node]={v:kfork,vinall_device_maps[node][my_name].items()}returnreverse_device_mapsdef_get_device_infos():from.importTensorPipeAgentagent=cast(TensorPipeAgent,api._get_current_rpc_agent())opts=agent._get_backend_options()device_count=torch.cuda.device_count()returndevice_count,opts.device_maps,opts.devicesdef_set_devices_and_reverse_device_map(agent):from.importTensorPipeAgentagent=cast(TensorPipeAgent,agent)# Group state is retrieved from local agent# On initialization, tensorpipe agent retrieves information from all existing workers, so group state is validmy_worker_info=agent.get_worker_info()my_name=my_worker_info.nameall_worker_infos=agent.get_worker_infos()# One round to get device_maps of all workers and construct reverse device mapsall_device_counts,all_device_maps,all_devices,all_names={},{},{},[]forworker_infoinall_worker_infos:worker_name=worker_info.nameifworker_name!=my_name:# TODO: make 
async?device_count,device_map,devices=api.rpc_sync(worker_name,_get_device_infos)else:opts=agent._get_backend_options()device_count,device_map,devices=torch.cuda.device_count(),opts.device_maps,opts.devicesall_device_counts[worker_name]=device_countall_device_maps[worker_name]=device_mapall_devices[worker_name]=devicesall_names.append(worker_name)_validate_device_maps(all_names,all_device_counts,all_device_maps,all_devices,is_static_group=False)reverse_device_maps=_create_reverse_mapping(my_name,all_names,all_device_maps)# Perform RPC call to all workers, including itself, to include newly joined worker information and device mapsforworker_nameinall_names:# Set device list for each workerall_devices[worker_name]=_create_device_list(all_devices[worker_name],all_device_maps[worker_name],reverse_device_maps)api.rpc_sync(worker_name,_update_group_membership,args=(my_worker_info,all_devices[worker_name],reverse_device_maps,True))def_tensorpipe_init_backend_handler(store,name,rank,world_size,rpc_backend_options):from.importTensorPipeAgentfrom.importTensorPipeRpcBackendOptionsifnotisinstance(store,dist.Store):raiseTypeError("`store` must be a c10d::Store. {}".format(store))ifnotisinstance(rpc_backend_options,TensorPipeRpcBackendOptions):raiseTypeError("`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {}".format(rpc_backend_options))iftorch.cuda.is_available():# It's necessary to initialize PyTorch CUDA states here (e.g.,# CUDACachingAllocator). If this is missing, we could hit errors like# "allocator not initialized", because other processes might send# CUDA-related RPC request to this process before user code in this# process initializes its PyTorch CUDA states.torch.cuda.init()device_count=torch.cuda.device_count()else:device_count=0is_static_group=Trueifworld_sizeelseFalse# world_size is specified so this is a static group (ranks cannot join and leave)ifis_static_group:# The agent's join method is required to behave like a barrier and perform# collective operations, for which it relies on a process group, instead of# re-implementing this on top of RPCs.group=_init_process_group(store,rank,world_size)reverse_device_maps,devices=_tensorpipe_exchange_and_check_all_device_maps(name,device_count,rpc_backend_options.device_maps,rpc_backend_options.devices,group,)# TODO: add try-except and destroy _agent in all processes if any fails.agent=TensorPipeAgent(store,name,rank,world_size,rpc_backend_options,reverse_device_maps,devices,)api._init_rpc_states(agent)# Run one dummy round of RPC to initialize channels/transports. 
Without# this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC# on that process before rpc.shutdown(), as the agent initialization can# take longer than 5s.api._all_gather(None,timeout=rpc_backend_options.rpc_timeout)# Need a barrier here to make sure no peers leave before the rank0 finishes# _all_gathergroup.barrier().wait()returnagent# initialization for dynamic rpc (ranks can join and leave)else:with_group_membership_management(store,name,True):# Construct TPAgent with empty reverse_device_map and devices# these properties will be updated after initializationagent=TensorPipeAgent(store,name,rank,world_size,rpc_backend_options,{},[],)api._init_rpc_states(agent)try:# Notify all workers in group this rank has joined and set devices and reverse_device_map# This is a synchronous operation that completes once all existing ranks are updated_set_devices_and_reverse_device_map(agent)passexceptException:api.shutdown()raisereturnagentregister_backend("TENSORPIPE",_tensorpipe_construct_rpc_backend_options_handler,_tensorpipe_init_backend_handler,)
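For context on how this registry is consumed from outside the built-in TensorPipe backend, below is a minimal sketch of registering a hypothetical custom backend. The backend name "MY_BACKEND" and both handler bodies are placeholders (not part of PyTorch); the init handler merely mirrors the (store, name, rank, world_size, rpc_backend_options) signature used by _tensorpipe_init_backend_handler above.

# Hypothetical example: wiring a custom backend into the registry above.
import torch.distributed.rpc.backend_registry as backend_registry


def _my_construct_rpc_backend_options_handler(rpc_timeout, init_method, **kwargs):
    # Expected to return an RpcBackendOptions-like object for this backend.
    raise NotImplementedError("placeholder options handler")


def _my_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
    # Expected to construct and return the backend's RpcAgent.
    raise NotImplementedError("placeholder init handler")


if not backend_registry.backend_registered("MY_BACKEND"):
    my_backend = backend_registry.register_backend(
        "MY_BACKEND",
        _my_construct_rpc_backend_options_handler,
        _my_init_backend_handler,
    )

Once registered, the new member appears as BackendType.MY_BACKEND, and construct_rpc_backend_options / init_backend dispatch to the two handlers through the BackendValue stored on that enum member.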