| | from typing import * |
| |
|
| | import torch |
| | import torch.distributed.rpc as rpc |
| | from torch import Tensor |
| | from torch._jit_internal import Future |
| | from torch.distributed.rpc import RRef |
| | from typing import Tuple |
| |
|
| |
|
# Placeholder overwritten at template-instantiation time with the concrete
# TorchScript module interface class for the generated remote module.
# NOTE(review): presumably substituted by the remote-module code generator
# (it is used below as a type argument to RRef) — confirm against the
# instantiator that produces this file.
module_interface_cls = None
| |
|
| |
|
def forward_async(self, *args, **kwargs):
    """Invoke ``forward`` on the remotely held module; return a Future.

    Prepends the module's RRef, the target device string, and the
    ``is_device_map_set`` flag to the positional args so that
    ``_remote_forward`` — executed on the RRef owner's worker via
    ``rpc.rpc_async`` — can locate the module and place tensors.

    Returns:
        A future resolving to whatever the remote ``forward`` returns.
    """
    args = (self.module_rref, self.device, self.is_device_map_set, *args)
    # Shallow copy of kwargs before handing off to RPC.
    # NOTE(review): the copy looks like a TorchScript/template requirement
    # rather than a semantic one — confirm before simplifying.
    kwargs = {**kwargs}
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )
| |
|
| |
|
def forward(self, *args, **kwargs):
    """Invoke ``forward`` on the remotely held module and block on the result.

    Same RPC path as ``forward_async`` (``rpc.rpc_async`` to the module
    owner, running ``_remote_forward``), but waits on the returned future,
    making this a synchronous call.

    Returns:
        The remote ``forward``'s result, once the RPC completes.
    """
    args = (self.module_rref, self.device, self.is_device_map_set, *args)
    # Shallow copy of kwargs, mirroring forward_async.
    kwargs = {**kwargs}
    ret_fut = rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )
    return ret_fut.wait()
| |
|
| |
|
# Methods attached to the generated remote-module class by the template
# instantiator. NOTE(review): consumers of this list are outside this file;
# confirm whether the order matters before reordering.
_generated_methods = [
    forward_async,
    forward,
]
| |
|
| |
|
| |
|
| |
|
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
    """Run ``forward`` on the locally held module, handling device placement.

    Executed on the RRef owner's worker (dispatched from ``forward`` /
    ``forward_async`` above). For a CUDA target device, Tensor inputs are
    copied onto that device first; unless the RPC device map is set, Tensor
    outputs are copied back to CPU before being returned over RPC.

    Args:
        module_rref: RRef whose local value is the module to invoke.
        device: Target device string, parsed with ``torch.device``.
        is_device_map_set: If True, outputs are returned as-is (no CPU copy).
        *args, **kwargs: Forwarded to the module's ``forward``.
    """
    module = module_rref.local_value()
    device = torch.device(device)

    # Non-CUDA target: no tensor movement needed, call straight through.
    if device.type != "cuda":
        return module.forward(*args, **kwargs)

    # Copy each Tensor positional arg onto the CUDA device. The manual
    # tuple accumulation (instead of a comprehension) and the Tuple[()]
    # annotation look like TorchScript workarounds — NOTE(review): this
    # function is presumably scripted by the template; confirm before
    # restructuring.
    args = (*args,)
    out_args: Tuple[()] = ()
    for arg in args:
        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
        out_args = out_args + arg

    # Copy Tensor keyword args onto the device. Only values are replaced
    # (keys unchanged), so mutating during .items() iteration is safe.
    kwargs = {**kwargs}
    for k, v in kwargs.items():
        if isinstance(v, Tensor):
            kwargs[k] = kwargs[k].to(device)

    # With an RPC device map configured, CUDA tensors can be transported
    # directly, so return the module's output unchanged.
    if is_device_map_set:
        return module.forward(*out_args, **kwargs)

    # Without a device map, RPC can only ship CPU tensors: copy each
    # Tensor element of the (assumed iterable — TODO confirm) output back
    # to CPU, preserving non-Tensor elements.
    ret: Tuple[()] = ()
    for i in module.forward(*out_args, **kwargs):
        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
        ret = ret + i
    return ret
| |
|