|
|
|
|
|
|
|
|
|
|
|
import subprocess |
|
import tempfile |
|
|
|
|
|
class PlasmaArray:
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.

    Sharing is only attempted when the array is at least
    ``_MIN_SHARED_NBYTES`` bytes and ``pyarrow.plasma`` can be imported;
    otherwise the instance is pickled like an ordinary object, array included.
    """

    # Arrays smaller than this many bytes (128 MiB) are never moved to plasma.
    _MIN_SHARED_NBYTES = 128 * 1024 * 1024

    def __init__(self, array):
        """Wrap *array*, which must expose an ``nbytes`` attribute."""
        super().__init__()
        self.array = array
        self.disable = array.nbytes < self._MIN_SHARED_NBYTES
        self.object_id = None
        self.path = None

        # Process-local handles, created lazily and reset to None in the
        # pickled state by __getstate__.
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """Return the ``pyarrow.plasma`` module, or None when unavailable.

        None is returned when sharing is disabled for this array or when the
        import fails. A failed import is retried on the next access because
        nothing is cached in that case.
        """
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        """Launch a ``plasma_store`` subprocess for this array.

        Does nothing when plasma is unavailable or a server was already
        started by this instance. On success ``self.path`` holds the name of
        a fresh temporary file that is passed to the store via ``-s``
        (presumably its socket path -- confirm against the plasma_store
        docs), and ``-m`` is given 1.05x the array's byte size.

        Raises:
            Whatever ``subprocess.Popen`` raises if the store cannot be
            launched; instance state is left untouched in that case.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        server_tmp = tempfile.NamedTemporaryFile()
        try:
            server = subprocess.Popen(
                [
                    "plasma_store",
                    "-m",
                    str(int(1.05 * self.array.nbytes)),
                    "-s",
                    server_tmp.name,
                ]
            )
        except Exception:
            # Do not leak the temp file or leave a half-initialised state
            # behind (a stale ``path`` would trip the assert on a retry).
            server_tmp.close()
            raise
        # Commit state only once the store process actually exists.
        self._server_tmp = server_tmp
        self.path = server_tmp.name
        self._server = server

    @property
    def client(self):
        """Return a plasma client connected at ``self.path``, created lazily.

        Requires ``self.path`` to be set and ``self.plasma`` to be available.
        """
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Return the state to pickle.

        Without plasma this is the instance ``__dict__`` itself, array
        included. With plasma, the array is put into the store on first use
        and the returned copy omits ``array`` and has every process-local
        handle set to None, so only ``object_id`` and ``path`` travel.
        """
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Restore from *state*, fetching the array from plasma if available."""
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        """Stop the store subprocess and close its temp file, if any.

        Attributes are read with ``getattr`` because ``__del__`` also runs on
        instances whose ``__init__`` never completed.
        """
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            # Reap the child so it does not linger as a zombie process.
            server.wait()
            self._server = None
        server_tmp = getattr(self, "_server_tmp", None)
        if server_tmp is not None:
            server_tmp.close()
            self._server_tmp = None
|
|