# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module is a general-purpose subprocess-based task runner.
"""
from __future__ import annotations
import os
import subprocess
import time
import typing
from collections import OrderedDict
from typing import List
# pylint: disable=R1732,R1710,R1721
class Task:
    """Task is an object containing a bash command,
    the process running that command, and the output of the process.
    """

    def __init__(
        self, idx: typing.Union[int, str], cmd: str, name: str, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        idx : Union[int, str]
            unique id for the task
        cmd : str
            bash command for the task (passed to ``subprocess.Popen``)
        name : str
            alias name of the task
        **kwargs
            Launch options read by ``__call__``: ``shell`` (passed to
            ``subprocess.Popen``; default False) and ``dev_flag`` (name of the
            environment variable that receives the assigned device id).
        """
        self._finished = False
        self._is_timeout = False
        self._failed = False
        self._idx = idx
        self._cmd = cmd
        self._name = name
        # Result slot; presumably filled in by the ``fproc`` given to ``pull``.
        self._ret = None
        self._assigned_dev = None
        self._proc = None
        self._timestamp = 0
        self._stdout = ""
        self._stderr = ""
        self._kwargs = kwargs

    def __call__(self, dev_id: int) -> None:
        """Execute the command in a new subprocess.

        stdout and stderr are captured through pipes and read later by ``pull``.
        NOTE(review): the pipes are only drained in ``pull``; a child writing
        more than the OS pipe buffer may block until the timeout kills it.

        Parameters
        ----------
        dev_id : int
            Target execution device id. If the task was created with a
            ``dev_flag`` kwarg, the id is exported to the child process under
            that environment variable name.
        """
        self._assigned_dev = dev_id
        use_shell = self._kwargs.get("shell", False)
        env = os.environ.copy()
        if "dev_flag" in self._kwargs:
            env[self._kwargs["dev_flag"]] = str(dev_id)
        self._proc = subprocess.Popen(
            self._cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            shell=use_shell,
        )
        # The timeout clock starts once the process has been launched.
        self._timestamp = time.time()

    def is_running(self) -> bool:
        """Check whether the task process has been launched.

        Returns
        -------
        bool
            True once ``__call__`` has started the process. This stays True
            after the process exits; use ``is_finished`` to detect completion.
        """
        return self._proc is not None

    def is_finished(self) -> bool:
        """Check whether the task is finished

        Returns
        -------
        bool
            Whether the task is finished (normally or by timeout)
        """
        return self._finished

    def is_timeout(self) -> bool:
        """Check whether the task timed out

        Returns
        -------
        bool
            Whether the task was killed because it exceeded the timeout
        """
        return self._is_timeout

    def poll(self, current_time: float, timeout: float) -> bool:
        """Given the current time, check whether
        the task is running, finished or timed out.

        A task still running after ``timeout`` seconds is killed and marked as
        timed out and failed. Calling ``poll`` again on a finished task has no
        further effect.

        Parameters
        ----------
        current_time : float
            Current timestamp (as returned by ``time.time()``)
        timeout : float
            Timeout in seconds, measured from the process launch

        Returns
        -------
        bool
            Whether the task is finished
        """
        if not self.is_running():
            return False
        if self._finished:
            # Already settled; never re-kill or re-classify a finished task.
            return True
        # Check normal completion first so a process that already exited is
        # not reported as timed out merely because it was polled late.
        if self._proc.poll() is not None:
            self._finished = True
            return True
        if current_time - self._timestamp > timeout:
            self._proc.kill()
            try:
                # Reap the killed child so it does not linger as a zombie;
                # bounded so a child stuck in the kernel cannot stall the caller.
                self._proc.wait(timeout=1)
            except subprocess.TimeoutExpired:
                pass
            self._finished = True
            self._is_timeout = True
            self._failed = True
            return True
        return False

    def pull(self, fproc: typing.Callable) -> None:
        """Pull stdout & stderr from process,
        process stdout & stderr with fproc, and set the output for the task.

        Failed (e.g. timed-out) tasks and tasks that were never launched are
        skipped and ``fproc`` is not called for them.

        Parameters
        ----------
        fproc : Callable
            Process function of the task given stdout & stderr; it is called
            with the task itself after ``_stdout``/``_stderr`` are set.
        """
        if self._failed or self._proc is None:
            return None
        self._stdout = self._proc.stdout.read().decode("utf-8")
        self._stderr = self._proc.stderr.read().decode("utf-8")
        fproc(self)

    def is_failed(self) -> bool:
        """Check whether the task failed

        Returns
        -------
        bool
            Whether the task failed (currently set only on timeout)
        """
        return self._failed

    def assigned_dev(self) -> typing.Optional[int]:
        """Return the assigned device id for the task

        Returns
        -------
        Optional[int]
            Assigned device id, or None if the task was never launched
        """
        return self._assigned_dev

    def __del__(self) -> None:
        """Clean up process resource by closing the captured pipes."""
        # getattr: __init__ may not have completed before garbage collection.
        proc = getattr(self, "_proc", None)
        if proc is None:
            return
        for stream in (proc.stdout, proc.stderr):
            if stream:
                stream.close()
class DeviceFarm:
    """Device Farm is a stateful object to
    schedule and assign tasks to the available devices.
    Devices are logical devices, can be CPUs or GPUs.
    """

    def __init__(self, devs: typing.Union[int, typing.Sequence[int]]) -> None:
        """Initialize a Device Farm given a list of device ids.

        Parameters
        ----------
        devs : Union[int, Sequence[int]]
            List (or tuple/range) of device ids. An int ``n`` is shorthand
            for the ids ``0 .. n-1``.

        Raises
        ------
        TypeError
            If ``devs`` is neither an int nor a list/tuple/range of ids.
        """
        if isinstance(devs, int):
            devs = range(devs)
        if not isinstance(devs, (list, tuple, range)):
            raise TypeError(
                f"devs must be an int or a list of device ids, got {type(devs).__name__}"
            )
        # Copy so later mutation of the caller's list cannot desynchronize
        # ``reset_all`` from the status map.
        self._devs = list(devs)
        # Maps device id -> busy flag (False means idle).
        self._dev_stats = OrderedDict.fromkeys(self._devs, False)

    def next_idle_dev(self) -> typing.Optional[int]:
        """Return the next idle (available) device id and mark it busy.

        Returns
        -------
        Union[None, int]
            The first idle device id in construction order. If all devices
            are busy, return None and leave all device states unchanged.
        """
        for dev_id, busy in self._dev_stats.items():
            if not busy:
                self._dev_stats[dev_id] = True
                return dev_id
        return None

    def reset_dev_state(self, dev_id: int) -> None:
        """Reset the device id state to idle

        Parameters
        ----------
        dev_id : int
            The id of device will be reset
        """
        self._dev_stats[dev_id] = False

    def reset_all(self) -> None:
        """Reset all devices to be idle"""
        for dev in self._devs:
            self._dev_stats[dev] = False
class BaseRunner:
    """Generic subprocess task runner for different purposes.

    Subclasses implement ``push`` to add tasks to the queue; ``join`` runs the
    queued tasks on the devices of a ``DeviceFarm`` and ``pull`` collects
    their results.
    """

    # Seconds to sleep between scheduling passes in ``join`` when no task
    # changed state, so waiting on subprocesses does not spin a CPU core.
    _poll_interval = 0.005

    def __init__(self, devs: List[int], tag: str, timeout: int = 10) -> None:
        """
        Parameters
        ----------
        devs : List[int]
            List of device ids for tasks.
        tag : str
            Runner's name tag
        timeout : int, optional
            Timeout value. Default is 10 (seconds).
        """
        self._tag = tag
        self._devs = DeviceFarm(devs)
        self._timeout = timeout
        self._finished_tasks = set()
        self._queue = []

    def join(self) -> None:
        """Wait until all queued tasks are finished (normally or by timeout).

        Each pass launches not-yet-started tasks on idle devices and polls the
        launched ones; a device is released as soon as its task finishes.
        """
        while True:
            current_time = time.time()
            progressed = False
            for task in self._queue:
                # Skip by the task's own state rather than by ``_idx`` so tasks
                # that accidentally share an id are not skipped forever.
                if task.is_finished():
                    continue
                if not task.is_running():
                    next_dev = self._devs.next_idle_dev()
                    if next_dev is not None:
                        task(next_dev)
                        progressed = True
                    continue
                if task.poll(current_time, self._timeout):
                    self._devs.reset_dev_state(task.assigned_dev())
                    self._finished_tasks.add(task._idx)
                    progressed = True
            if all(task.is_finished() for task in self._queue):
                break
            if not progressed:
                # Nothing changed this pass; back off instead of busy-waiting.
                time.sleep(self._poll_interval)

    def reset(self) -> None:
        """Reset runner, clear task queue and device states"""
        self._devs.reset_all()
        self._finished_tasks = set()
        self._queue = []

    def pull(self, ftask_proc: typing.Callable, fret_proc: typing.Callable) -> List:
        """Pull results from all tasks executed on the runner, then reset it.

        Parameters
        ----------
        ftask_proc : Callable
            Function to process each task's output
        fret_proc : Callable
            Function to extract returns from task

        Returns
        -------
        List
            Aggregated returns from all finished tasks whose ``_ret`` is set
        """
        ret = []
        for task in self._queue:
            task.pull(ftask_proc)
            if task.is_finished() and task._ret is not None:
                ret.append(fret_proc(task))
        self.reset()
        return ret

    def push(self, idx: typing.Union[int, str], cmd: str):
        """Push a task into runner; must be implemented by subclasses.

        Parameters
        ----------
        idx : Union[int, str]
            id of the task
        cmd : str
            bash command line for the task

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError