This commit is contained in:
john 2021-07-05 16:45:06 -04:00
parent 7b6f56ec4e
commit 3473cc262e
18 changed files with 2171 additions and 1 deletion

11
.editorconfig Normal file

@@ -0,0 +1,11 @@
# https://editorconfig.org/
root = true
[*]
indent_style = space
indent_size = 4
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
charset = utf-8

24
LICENSE.txt Normal file

@@ -0,0 +1,24 @@
MIT License
Copyright (c) 2021 Johnathan Nader
Copyright (c) 2020 Lucas Nestler
Copyright (c) 2020 Dr. Tim Scarfe
Copyright (c) 2020 Daniel Gatis
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

12
MANIFEST.in Normal file

@@ -0,0 +1,12 @@
include pyproject.toml
# Include the README
include *.md
# Include the license file
include LICENSE.txt
# Include the data files
recursive-include data *
include requirements.txt

165
README.md

@@ -1 +1,164 @@
# backgroundremover
# BackgroundRemover
A command line tool to remove background from [video](https://backgroundremover.app/video) and [image](https://backgroundremover.app/image), brought to you by [BackgroundRemover.app](https://backgroundremover.app), an app made by [nadermx](https://john.nader.mx) and powered by this tool.
<img alt="background remover video" src="https://backgroundremover.app/static/backgroundremover.gif" height="200" />
<img alt="green screen matte key file" src="https://backgroundremover.app/static/matte.gif" height="200" width="110" />
<img alt="background remover image" src="https://backgroundremover.app/static/backgroundremoverexample.png" height="200"/>
### Requirements
* python 3.6 (the only version tested so far, but it may work with newer versions)
* python3.6-dev
* torch and torchvision stable version (https://pytorch.org)
* ffmpeg 4.2+
#### How to install torch and ffmpeg
Go to https://pytorch.org, scroll down to the `INSTALL PYTORCH` section, and follow the instructions.
For example:
```
PyTorch Build: Stable (1.7.1)
Your OS: Windows
Package: Pip
Language: Python
CUDA: None
```
To install ffmpeg and python3.6-dev
```
sudo apt install ffmpeg python3.6-dev
```
To install torch:
```
pip install --upgrade pip
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
```
### Installation
To install backgroundremover, install it from PyPI
```bash
pip install backgroundremover
```
# Usage as a cli
## Image
Remove the background from a local file image
```bash
backgroundremover -i "/path/to/image.jpeg" -o "output.png"
```
### Advanced usage for image background removal
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
```bash
backgroundremover -i "/path/to/image.jpeg" -a -ae 15 -o "output.png"
```
Change the model for different background removal methods between `u2netp`, `u2net`, or `u2net_human_seg`
```bash
backgroundremover -i "/path/to/image.jpeg" -m "u2net_human_seg" -o "output.png"
```
## Video
### remove background from video and make transparent mov
```bash
backgroundremover -i "/path/to/video.mp4" -tv -o "output.mov"
```
### remove background from local video and overlay it over another video
```bash
backgroundremover -i "/path/to/video.mp4" -tov -bv "/path/to/videotobeoverlayed.mp4" -o "output.mov"
```
### remove background from video and make transparent gif
```bash
backgroundremover -i "/path/to/video.mp4" -tg -o "output.gif"
```
### Make matte key file (green screen overlay)
Make a matte file for Premiere
```bash
backgroundremover -i "/path/to/video.mp4" -mk -o "output.matte.mp4"
```
### Advanced usage for video
Change the framerate of the video (by default the framerate of the source video is used)
```bash
backgroundremover -i "/path/to/video.mp4" -fr 30 -tv -o "output.mov"
```
Change the GPU batch size of the video (default is set to 2)
```bash
backgroundremover -i "/path/to/video.mp4" -gb 4 -tv -o "output.mov"
```
Change the number of workers working on video (default is set to 1)
```bash
backgroundremover -i "/path/to/video.mp4" -wn 4 -tv -o "output.mov"
```
Change the model for different background removal methods between `u2netp`, `u2net`, or `u2net_human_seg`
```bash
backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg" -tv -o "output.mov"
```
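# Usage as a library
You can also call the background removal directly from Python. The sketch below is a minimal example based on the `remove` function in `backgroundremover.bg` from this commit (the file paths are placeholders):
```python
# Minimal sketch: remove the background from one image via the Python API.
from backgroundremover.bg import remove

def remove_bg(src_img_path, out_img_path):
    # read the raw bytes of the input image
    with open(src_img_path, "rb") as input_file:
        data = input_file.read()
    # remove() returns the PNG bytes of the cutout
    img = remove(
        data,
        model_name="u2net",            # or "u2netp" / "u2net_human_seg"
        alpha_matting=True,
        alpha_matting_foreground_threshold=240,
        alpha_matting_background_threshold=10,
        alpha_matting_erode_structure_size=10,
        alpha_matting_base_size=1000,
    )
    with open(out_img_path, "wb") as output_file:
        output_file.write(img)

remove_bg("/path/to/image.jpeg", "output.png")
```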
## Todo
- convert logic from video to image to utilize more GPU on image removal
- remove duplicate imports from image and video of u2net models
- clean up documentation a bit more
- add ability to adjust and give feedback on images or videos to improve the datasets
- other
### Pull requests
Accepted
### If you like this library
Give a link to our project [BackgroundRemover.app](https://backgroundremover.app) or this repository, telling people that you like it or use it.
### Reason for project
We made this into its own package after merging together parts of other projects and adding a few features of our own, partly by posting pieces as bounty questions on Superuser. We were also asked on Hacker News to open source the image part, so we decided to add video support and a bit more.
### References
- https://arxiv.org/pdf/2005.09007.pdf
- https://github.com/NathanUA/U-2-Net
- https://github.com/pymatting/pymatting
- https://github.com/danielgatis/rembg
- https://github.com/ecsplendid/rembg-greenscreen
- https://superuser.com/questions/1647590/have-ffmpeg-merge-a-matte-key-file-over-the-normal-video-file-removing-the-backg
- https://superuser.com/questions/1648680/ffmpeg-alphamerge-two-videos-into-a-gif-with-transparent-background/1649339?noredirect=1#comment2522687_1649339
- https://superuser.com/questions/1649817/ffmpeg-overlay-a-video-after-alphamerging-two-others/1649856#1649856
### License
- Copyright (c) 2021-present [Johnathan Nader](https://github.com/nadermx)
- Copyright (c) 2020-present [Lucas Nestler](https://github.com/ClashLuke)
- Copyright (c) 2020-present [Dr. Tim Scarfe](https://github.com/ecsplendid)
- Copyright (c) 2020-present [Daniel Gatis](https://github.com/danielgatis)
Licensed under [MIT License](./LICENSE.txt)

5
pyproject.toml Normal file

@@ -0,0 +1,5 @@
[build-system]
# These are the assumed default build requirements from pip:
# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta"

16
requirements.txt Normal file

@@ -0,0 +1,16 @@
numpy>=1.19.4
scikit-image>=0.17.2
torch>=1.7.0
torchvision>=0.8.1
waitress>=1.4.4
tqdm>=4.51.0
requests>=2.24.0
scipy>=1.5.4
pymatting>=1.1.1
filetype>=1.0.7
hsh>=1.1.0
more_itertools==8.7.0
moviepy==1.0.3
Pillow==8.1.1
ffmpeg-python

4
setup.cfg Normal file

@@ -0,0 +1,4 @@
[metadata]
# This includes the license file(s) in the wheel.
# https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file
license_files = LICENSE.txt

35
setup.py Normal file

@@ -0,0 +1,35 @@
import pathlib
from setuptools import find_packages, setup
here = pathlib.Path(__file__).parent.resolve()
long_description = (here / "README.md").read_text(encoding="utf-8")
with open("requirements.txt") as f:
requireds = f.read().splitlines()
setup(
name="backgroundremover",
version="0.1.1",
description="Background remover from image and video",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/nadermx/backgroundremover",
author="Johnathan Nader",
author_email="john@nader.mx",
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
],
keywords="remove, background, u2net, remove background, background remover",
package_dir={"": "src"},
packages=find_packages(where="src"),
python_requires=">=3.6, <4",
install_requires=requireds,
entry_points={
"console_scripts": [
"backgroundremover=backgroundremover.cmd.cli:main",
],
},
)

9
src/backgroundremover/__init__.py Normal file

@@ -0,0 +1,9 @@
"""
backgroundremover
A library to remove background from videos and images
"""
__version__ = "0.1.1"
__author__ = 'Johnathan Nader'
__credits__ = 'BackgroundRemover.app'

201
src/backgroundremover/bg.py Normal file

@@ -0,0 +1,201 @@
import functools
import io
import os
import typing
from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage.morphology import binary_erosion
import moviepy.editor as mpy
import numpy as np
import requests
import torch
import torch.nn.functional
from hsh.library.hash import Hasher
from tqdm import tqdm
from .u2net import detect, u2net
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class Net(torch.nn.Module):
def __init__(self, model_name):
super(Net, self).__init__()
hasher = Hasher()
model, hash_val, drive_target, env_var = {
'u2netp': (u2net.U2NETP,
'e4f636406ca4e2af789941e7f139ee2e',
'1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy',
'U2NET_PATH'),
'u2net': (u2net.U2NET,
'09fb4e49b7f785c9f855baf94916840a',
'1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P',
'U2NET_PATH'),
'u2net_human_seg': (u2net.U2NET,
'347c3d51b01528e5c6c071e3cff1cb55',
'1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ',
'U2NET_PATH')
}[model_name]
path = os.environ.get(env_var, os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")))
net = model(3, 1)
if not os.path.exists(path) or hasher.md5(path) != hash_val:
head, tail = os.path.split(path)
os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={"id": drive_target}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
if token:
params = {"id": drive_target, "confirm": token}
response = session.get(URL, params=params, stream=True)
total = int(response.headers.get("content-length", 0))
with open(path, "wb") as file, tqdm(
desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
net.load_state_dict(torch.load(path, map_location=torch.device(DEVICE)))
net.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
net.eval()
self.net = net
def forward(self, block_input: torch.Tensor):
image_data = block_input.permute(0, 3, 1, 2)
original_shape = image_data.shape[2:]
image_data = torch.nn.functional.interpolate(image_data, (320, 320), mode='bilinear')
image_data = (image_data / 255 - 0.485) / 0.229
out = self.net(image_data)[0][:, 0:1]
ma = torch.max(out)
mi = torch.min(out)
out = (out - mi) / (ma - mi) * 255
out = torch.nn.functional.interpolate(out, original_shape, mode='bilinear')
out = out[:, 0]
out = out.to(dtype=torch.uint8, device=torch.device('cpu'), non_blocking=True).detach()
return out
def alpha_matting_cutout(
img,
mask,
foreground_threshold,
background_threshold,
erode_structure_size,
base_size,
):
size = img.size
img.thumbnail((base_size, base_size), Image.LANCZOS)
mask = mask.resize(img.size, Image.LANCZOS)
img = np.asarray(img)
mask = np.asarray(mask)
# guess likely foreground/background
is_foreground = mask > foreground_threshold
is_background = mask < background_threshold
# erode foreground/background
structure = None
if erode_structure_size > 0:
structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
is_foreground = binary_erosion(is_foreground, structure=structure)
is_background = binary_erosion(is_background, structure=structure, border_value=1)
# build trimap
# 0 = background
# 128 = unknown
# 255 = foreground
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
trimap[is_foreground] = 255
trimap[is_background] = 0
# build the cutout image
img_normalized = img / 255.0
trimap_normalized = trimap / 255.0
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
foreground = estimate_foreground_ml(img_normalized, alpha)
cutout = stack_images(foreground, alpha)
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
cutout = Image.fromarray(cutout)
cutout = cutout.resize(size, Image.LANCZOS)
return cutout
def naive_cutout(img, mask):
empty = Image.new("RGBA", (img.size), 0)
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
return cutout
@functools.lru_cache(maxsize=None)
def get_model(model_name):
if model_name == "u2netp":
return detect.load_model(model_name="u2netp")
if model_name == "u2net_human_seg":
return detect.load_model(model_name="u2net_human_seg")
else:
return detect.load_model(model_name="u2net")
def remove(
data,
model_name="u2net",
alpha_matting=False,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_structure_size=10,
alpha_matting_base_size=1000,
):
model = get_model(model_name)
img = Image.open(io.BytesIO(data)).convert("RGB")
mask = detect.predict(model, np.array(img)).convert("L")
if alpha_matting:
cutout = alpha_matting_cutout(
img,
mask,
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_structure_size,
alpha_matting_base_size,
)
else:
cutout = naive_cutout(img, mask)
bio = io.BytesIO()
cutout.save(bio, "PNG")
return bio.getbuffer()
def iter_frames(path):
return mpy.VideoFileClip(path).resize(height=320).iter_frames(dtype="uint8")
@torch.no_grad()
def remove_many(image_data: typing.List[np.ndarray], net: Net):
image_data = np.stack(image_data)
image_data = torch.as_tensor(image_data, dtype=torch.float32, device=DEVICE)
return net(image_data).numpy()

0
src/backgroundremover/cmd/__init__.py Normal file

264
src/backgroundremover/cmd/cli.py Normal file

@@ -0,0 +1,264 @@
import argparse
import glob
import os
from distutils.util import strtobool
from ..bg import remove
from ..utilities import matte_key, transparentgif, transparentvideo, transparentvideoovervideo, transparentvideooverimage, \
transparentgifwithbackground
import torch
from .. import utilities
def main():
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
if len(model_choices) == 0:
model_choices = ["u2net", "u2netp", "u2net_human_seg"]
ap = argparse.ArgumentParser()
ap.add_argument(
"-m",
"--model",
default="u2net",
type=str,
choices=model_choices,
help="The model name, u2net, u2netp, u2net_human_seg",
)
ap.add_argument(
"-a",
"--alpha-matting",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="When true use alpha matting cutout.",
)
ap.add_argument(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
help="The trimap foreground threshold.",
)
ap.add_argument(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
help="The trimap background threshold.",
)
ap.add_argument(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
help="Size of element used for the erosion.",
)
ap.add_argument(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
help="The image base size.",
)
ap.add_argument(
"-wn",
"--workernodes",
default=1,
type=int,
help="Number of parallel workers"
)
ap.add_argument(
"-gb",
"--gpubatchsize",
default=2,
type=int,
help="GPU batchsize"
)
ap.add_argument(
"-fr",
"--framerate",
default=-1,
type=int,
help="Override the frame rate"
)
ap.add_argument(
"-fl",
"--framelimit",
default=-1,
type=int,
help="Limit the number of frames to process for quick testing.",
)
ap.add_argument(
"-mk",
"--mattekey",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Output the Matte key file",
)
ap.add_argument(
"-tv",
"--transparentvideo",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Output transparent video format mov",
)
ap.add_argument(
"-tov",
"--transparentvideoovervideo",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Overlay transparent video over another video",
)
ap.add_argument(
"-toi",
"--transparentvideooverimage",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Overlay transparent video over another video",
)
ap.add_argument(
"-tg",
"--transparentgif",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Make transparent gif from video",
)
ap.add_argument(
"-tgwb",
"--transparentgifwithbackground",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="Make transparent background overlay a background image",
)
ap.add_argument(
"-i",
"--input",
nargs="?",
default="-",
type=argparse.FileType("rb"),
help="Path to the input video or image.",
)
ap.add_argument(
"-bi",
"--backgroundimage",
nargs="?",
default="-",
type=argparse.FileType("rb"),
help="Path to background image.",
)
ap.add_argument(
"-bv",
"--backgroundvideo",
nargs="?",
default="-",
type=argparse.FileType("rb"),
help="Path to background video.",
)
ap.add_argument(
"-o",
"--output",
nargs="?",
default="-",
type=argparse.FileType("wb"),
help="Path to the output",
)
args = ap.parse_args()
if args.input.name.rsplit('.', 1)[1] in ['mp4', 'mov', 'webm', 'ogg', 'gif']:
if args.mattekey:
matte_key(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
elif args.transparentvideo:
transparentvideo(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
elif args.transparentvideoovervideo:
transparentvideoovervideo(os.path.abspath(args.output.name), os.path.abspath(args.backgroundvideo.name),
os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
elif args.transparentvideooverimage:
transparentvideooverimage(os.path.abspath(args.output.name), os.path.abspath(args.backgroundimage.name),
os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
elif args.transparentgif:
transparentgif(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
elif args.transparentgifwithbackground:
transparentgifwithbackground(os.path.abspath(args.output.name), os.path.abspath(args.backgroundimage.name), os.path.abspath(args.input.name),
worker_nodes=args.workernodes,
gpu_batchsize=args.gpubatchsize,
model_name=args.model,
frame_limit=args.framelimit,
framerate=args.framerate)
else:
print(args.output.name)
r = lambda i: i.buffer.read() if hasattr(i, "buffer") else i.read()
w = lambda o, data: o.buffer.write(data) if hasattr(o, "buffer") else o.write(data)
w(
args.output,
remove(
r(args.input),
model_name=args.model,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
alpha_matting_base_size=args.alpha_matting_base_size,
),
)
if __name__ == "__main__":
torch.multiprocessing.set_start_method('spawn')
main()

98
src/backgroundremover/cmd/server.py Normal file

@@ -0,0 +1,98 @@
import os
import glob
import argparse
from io import BytesIO
from urllib.parse import unquote_plus
from urllib.request import urlopen
from flask import Flask, request, send_file
from waitress import serve
from ..bg import remove
app = Flask(__name__)
@app.route("/", methods=["GET", "POST"])
def index():
file_content = ""
if request.method == "POST":
if "file" not in request.files:
return {"error": "missing post form param 'file'"}, 400
file_content = request.files["file"].read()
if request.method == "GET":
url = request.args.get("url", type=str)
if url is None:
return {"error": "missing query param 'url'"}, 400
file_content = urlopen(unquote_plus(url)).read()
if file_content == "":
return {"error": "File content is empty"}, 400
alpha_matting = "a" in request.values
af = request.values.get("af", type=int, default=240)
ab = request.values.get("ab", type=int, default=10)
ae = request.values.get("ae", type=int, default=10)
az = request.values.get("az", type=int, default=1000)
model = request.args.get("model", type=str, default="u2net")
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
if len(model_choices) == 0:
model_choices = ["u2net", "u2netp", "u2net_human_seg"]
if model not in model_choices:
return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
try:
return send_file(
BytesIO(
remove(
file_content,
model_name=model,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=af,
alpha_matting_background_threshold=ab,
alpha_matting_erode_structure_size=ae,
alpha_matting_base_size=az,
)
),
mimetype="image/png",
)
except Exception as e:
app.logger.exception(e, exc_info=True)
return {"error": "oops, something went wrong!"}, 500
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"-a",
"--addr",
default="0.0.0.0",
type=str,
help="The IP address to bind to.",
)
ap.add_argument(
"-p",
"--port",
default=5000,
type=int,
help="The port to bind to.",
)
args = ap.parse_args()
serve(app, host=args.addr, port=args.port)
if __name__ == "__main__":
main()
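For reference, a minimal sketch of calling this HTTP endpoint from Python, assuming the server is running on the default address and port (0.0.0.0:5000); the file paths are placeholders:
```python
# Sketch: send an image to the running server and save the returned PNG cutout.
import requests

with open("/path/to/image.jpeg", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/",
        params={"model": "u2net"},        # the model is read from the query string
        files={"file": f},                # required multipart form field
        data={"a": "true", "ae": 15},     # optional alpha matting parameters
    )
resp.raise_for_status()
with open("output.png", "wb") as out:
    out.write(resp.content)              # the response body is a PNG
```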

0
src/backgroundremover/u2net/__init__.py Normal file

324
src/backgroundremover/u2net/data_loader.py Normal file

@@ -0,0 +1,324 @@
# data loader
from __future__ import division, print_function
import random
import numpy as np
import torch
from skimage import color, io, transform
from torch.utils.data import DataLoader, Dataset
# ==========================dataset load==========================
class RescaleT(object):
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
# img = transform.resize(image,(new_h,new_w),mode='constant')
# lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
img = transform.resize(
image, (self.output_size, self.output_size), mode="constant"
)
lbl = transform.resize(
label,
(self.output_size, self.output_size),
mode="constant",
order=0,
preserve_range=True,
)
return {"imidx": imidx, "image": img, "label": lbl}
class Rescale(object):
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
img = transform.resize(image, (new_h, new_w), mode="constant")
lbl = transform.resize(
label, (new_h, new_w), mode="constant", order=0, preserve_range=True
)
return {"imidx": imidx, "image": img, "label": lbl}
class RandomCrop(object):
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top : top + new_h, left : left + new_w]
label = label[top : top + new_h, left : left + new_w]
return {"imidx": imidx, "image": image, "label": label}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
tmpLbl = np.zeros(label.shape)
image = image / np.max(image)
if np.max(label) < 1e-6:
label = label
else:
label = label / np.max(label)
if image.shape[2] == 1:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
else:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
tmpLbl[:, :, 0] = label[:, :, 0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {
"imidx": torch.from_numpy(imidx),
"image": torch.from_numpy(tmpImg),
"label": torch.from_numpy(tmpLbl),
}
class ToTensorLab(object):
"""Convert ndarrays in sample to Tensors."""
def __init__(self, flag=0):
self.flag = flag
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
tmpLbl = np.zeros(label.shape)
if np.max(label) < 1e-6:
label = label
else:
label = label / np.max(label)
# change the color space
if self.flag == 2: # with rgb and Lab colors
tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
if image.shape[2] == 1:
tmpImgt[:, :, 0] = image[:, :, 0]
tmpImgt[:, :, 1] = image[:, :, 0]
tmpImgt[:, :, 2] = image[:, :, 0]
else:
tmpImgt = image
tmpImgtl = color.rgb2lab(tmpImgt)
# nomalize image to range [0,1]
tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
)
tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
)
tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
)
tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
)
tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
)
tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
tmpImg[:, :, 0]
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
tmpImg[:, :, 1]
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
tmpImg[:, :, 2]
)
tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
tmpImg[:, :, 3]
)
tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
tmpImg[:, :, 4]
)
tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
tmpImg[:, :, 5]
)
elif self.flag == 1: # with Lab color
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
if image.shape[2] == 1:
tmpImg[:, :, 0] = image[:, :, 0]
tmpImg[:, :, 1] = image[:, :, 0]
tmpImg[:, :, 2] = image[:, :, 0]
else:
tmpImg = image
tmpImg = color.rgb2lab(tmpImg)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
)
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
tmpImg[:, :, 0]
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
tmpImg[:, :, 1]
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
tmpImg[:, :, 2]
)
else: # with rgb color
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
image = image / np.max(image)
if image.shape[2] == 1:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
else:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
tmpLbl[:, :, 0] = label[:, :, 0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {
"imidx": torch.from_numpy(imidx),
"image": torch.from_numpy(tmpImg),
"label": torch.from_numpy(tmpLbl),
}
class SalObjDataset(Dataset):
def __init__(self, img_name_list, lbl_name_list, transform=None):
# self.root_dir = root_dir
# self.image_name_list = glob.glob(image_dir+'*.png')
# self.label_name_list = glob.glob(label_dir+'*.png')
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self, idx):
# image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
# label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
if 0 == len(self.label_name_list):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])
label = np.zeros(label_3.shape[0:2])
if 3 == len(label_3.shape):
label = label_3[:, :, 0]
elif 2 == len(label_3.shape):
label = label_3
if 3 == len(image.shape) and 2 == len(label.shape):
label = label[:, :, np.newaxis]
elif 2 == len(image.shape) and 2 == len(label.shape):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
sample = {"imidx": imidx, "image": image, "label": label}
if self.transform:
sample = self.transform(sample)
return sample
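For reference, a short sketch of how these pieces compose for batched loading, mirroring the single-sample preprocessing used in detect.py; the module path and the image/label paths are assumptions:
```python
# Sketch: wire the transforms and dataset together and iterate one batch.
import glob
from torch.utils.data import DataLoader
from torchvision import transforms
from backgroundremover.u2net import data_loader  # assumed module path

img_list = sorted(glob.glob("/path/to/images/*.jpg"))   # placeholder paths
lbl_list = sorted(glob.glob("/path/to/labels/*.png"))

dataset = data_loader.SalObjDataset(
    img_name_list=img_list,
    lbl_name_list=lbl_list,
    transform=transforms.Compose(
        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
    ),
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
for batch in loader:
    images, labels = batch["image"], batch["label"]  # (B, 3, 320, 320) and (B, 1, 320, 320)
    break
```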

178
src/backgroundremover/u2net/detect.py Normal file

@@ -0,0 +1,178 @@
import errno
import os
import sys
import numpy as np
import requests
import torch
from hsh.library.hash import Hasher
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from . import data_loader, u2net
def download_file_from_google_drive(id, fname, destination):
head, tail = os.path.split(destination)
os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={"id": id}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
if token:
params = {"id": id, "confirm": token}
response = session.get(URL, params=params, stream=True)
total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm(
desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
def load_model(model_name: str = "u2net"):
hasher = Hasher()
if model_name == "u2netp":
net = u2net.U2NETP(3, 1)
path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
"u2netp.pth",
path,
)
elif model_name == "u2net":
net = u2net.U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
"u2net.pth",
path,
)
elif model_name == "u2net_human_seg":
net = u2net.U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
download_file_from_google_drive(
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
"u2net_human_seg.pth",
path,
)
else:
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
try:
if torch.cuda.is_available():
net.load_state_dict(torch.load(path))
net.to(torch.device("cuda"))
else:
net.load_state_dict(
torch.load(
path,
map_location="cpu",
)
)
except FileNotFoundError:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
)
net.eval()
return net
def norm_pred(d):
ma = torch.max(d)
mi = torch.min(d)
dn = (d - mi) / (ma - mi)
return dn
def preprocess(image):
label_3 = np.zeros(image.shape)
label = np.zeros(label_3.shape[0:2])
if 3 == len(label_3.shape):
label = label_3[:, :, 0]
elif 2 == len(label_3.shape):
label = label_3
if 3 == len(image.shape) and 2 == len(label.shape):
label = label[:, :, np.newaxis]
elif 2 == len(image.shape) and 2 == len(label.shape):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
transform = transforms.Compose(
[data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
)
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
return sample
def predict(net, item):
sample = preprocess(item)
with torch.no_grad():
if torch.cuda.is_available():
inputs_test = torch.cuda.FloatTensor(
sample["image"].unsqueeze(0).cuda().float()
)
else:
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
pred = d1[:, 0, :, :]
predict = norm_pred(pred)
predict = predict.squeeze()
predict_np = predict.cpu().detach().numpy()
img = Image.fromarray(predict_np * 255).convert("RGB")
del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
return img
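For reference, a minimal sketch of using `load_model` and `predict` directly to get a segmentation mask for a single image; the module path and the input path are assumptions:
```python
# Sketch: run one image through the network and save the predicted mask.
import numpy as np
from PIL import Image
from backgroundremover.u2net import detect  # assumed module path

net = detect.load_model(model_name="u2net")              # downloads weights on first use
img = Image.open("/path/to/image.jpeg").convert("RGB")   # placeholder path
mask = detect.predict(net, np.array(img)).convert("L")   # 320x320 grayscale saliency mask
mask.resize(img.size).save("mask.png")                   # scale back to the input size
```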

541
src/backgroundremover/u2net/u2net.py Normal file

@@ -0,0 +1,541 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
return src
### RSU-7 ###
class RSU7(nn.Module): # UNet07DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module): # UNet06DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module): # UNet05DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module): # UNet04DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
##### U^2-Net ####
class U2NET(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(U2NET, self).__init__()
self.stage1 = RSU7(in_ch, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6, out_ch, 1)
def forward(self, x):
hx = x
# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return (
torch.sigmoid(d0),
torch.sigmoid(d1),
torch.sigmoid(d2),
torch.sigmoid(d3),
torch.sigmoid(d4),
torch.sigmoid(d5),
torch.sigmoid(d6),
)
### U^2-Net small ###
class U2NETP(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(U2NETP, self).__init__()
self.stage1 = RSU7(in_ch, 16, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 16, 64)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(64, 16, 64)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(64, 16, 64)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(64, 16, 64)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(64, 16, 64)
# decoder
self.stage5d = RSU4F(128, 16, 64)
self.stage4d = RSU4(128, 16, 64)
self.stage3d = RSU5(128, 16, 64)
self.stage2d = RSU6(128, 16, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6, out_ch, 1)
def forward(self, x):
hx = x
# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# decoder
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return (
torch.sigmoid(d0),
torch.sigmoid(d1),
torch.sigmoid(d2),
torch.sigmoid(d3),
torch.sigmoid(d4),
torch.sigmoid(d5),
torch.sigmoid(d6),
)
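A quick shape-check sketch for the networks defined above, using random weights and a dummy input; the import path is an assumption:
```python
# Sketch: instantiate the small model and confirm the seven sigmoid outputs.
import torch
from backgroundremover.u2net.u2net import U2NETP  # assumed module path

net = U2NETP(in_ch=3, out_ch=1)
net.eval()
x = torch.randn(1, 3, 320, 320)              # dummy RGB batch at the working resolution
with torch.no_grad():
    d0, d1, d2, d3, d4, d5, d6 = net(x)      # fused map d0 plus six side outputs
print(d0.shape)                              # torch.Size([1, 1, 320, 320])
```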

285
src/backgroundremover/utilities.py Normal file

@@ -0,0 +1,285 @@
import os
import math
import multiprocessing
import subprocess as sp
import time
import ffmpeg
import numpy as np
import torch
import tempfile
from .bg import DEVICE, Net, iter_frames, remove_many
import shlex
def worker(worker_nodes,
worker_index,
result_dict,
model_name,
gpu_batchsize,
total_frames,
frames_dict):
print(F"WORKER {worker_index} ONLINE")
output_index = worker_index + 1
base_index = worker_index * gpu_batchsize
net = Net(model_name)
script_net = None
for fi in (list(range(base_index + i * worker_nodes * gpu_batchsize,
min(base_index + i * worker_nodes * gpu_batchsize + gpu_batchsize, total_frames)))
for i in range(math.ceil(total_frames / worker_nodes / gpu_batchsize))):
if not fi:
break
# are we processing frames faster than the frame ripper is saving them?
last = fi[-1]
while last not in frames_dict:
time.sleep(0.1)
input_frames = [frames_dict[index] for index in fi]
if script_net is None:
script_net = torch.jit.trace(net,
torch.as_tensor(np.stack(input_frames), dtype=torch.float32, device=DEVICE))
result_dict[output_index] = remove_many(input_frames, script_net)
# clean up the frame buffer
for fdex in fi:
del frames_dict[fdex]
output_index += worker_nodes
def capture_frames(file_path, frames_dict, prefetched_samples, total_frames):
print(F"WORKER FRAMERIPPER ONLINE")
for idx, frame in enumerate(iter_frames(file_path)):
frames_dict[idx] = frame
while len(frames_dict) > prefetched_samples:
time.sleep(0.1)
if idx > total_frames:
break
def matte_key(output, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
manager = multiprocessing.Manager()
results_dict = manager.dict()
frames_dict = manager.dict()
print(file_path)
info = ffmpeg.probe(file_path)
total_frames = int(info["streams"][0]["nb_frames"])
if frame_limit != -1:
total_frames = min(frame_limit, total_frames)
fr = info["streams"][0]["r_frame_rate"]
if framerate == -1:
print(F"FRAME RATE DETECTED: {fr} (if this looks wrong, override the frame rate)")
framerate = math.ceil(eval(fr))
print(F"FRAME RATE: {framerate} TOTAL FRAMES: {total_frames}")
p = multiprocessing.Process(target=capture_frames,
args=(file_path, frames_dict, gpu_batchsize * prefetched_batches, total_frames))
p.start()
# note I am deliberately not using pool
# we can't trust it to run all the threads concurrently (or at all)
workers = [multiprocessing.Process(target=worker,
args=(worker_nodes, wn, results_dict, model_name, gpu_batchsize, total_frames,
frames_dict))
for wn in range(worker_nodes)]
for w in workers:
w.start()
command = None
proc = None
frame_counter = 0
for i in range(math.ceil(total_frames / worker_nodes)):
for wx in range(worker_nodes):
hash_index = i * worker_nodes + 1 + wx
while hash_index not in results_dict:
time.sleep(0.1)
frames = results_dict[hash_index]
# don't block access to it anymore
del results_dict[hash_index]
for frame in frames:
if command is None:
command = ['nice', '-10',
'ffmpeg',
'-y',
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-s', F"{frame.shape[1]}x320",
'-pix_fmt', 'gray',
'-r', F"{framerate}",
'-i', '-',
'-an',
'-vcodec', 'mpeg4',
'-b:v', '2000k',
'%s' % output]
proc = sp.Popen(command, stdin=sp.PIPE)
proc.stdin.write(frame.tostring())
frame_counter = frame_counter + 1
if frame_counter >= total_frames:
p.join()
for w in workers:
w.join()
proc.stdin.close()
proc.wait()
print(F"FINISHED ALL FRAMES ({total_frames})!")
return
p.join()
for w in workers:
w.join()
proc.stdin.close()
proc.wait()
return
def transparentgif(output, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
matte_key(temp_file, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit,
prefetched_batches,
framerate)
cmd = "nice -10 ffmpeg -y -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1,fps=10,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse' -shortest %s" % (
file_path, temp_file, output)
sp.run(shlex.split(cmd))
print("Process finished")
return
def transparentgifwithbackground(output, overlay, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
matte_key(temp_file, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit,
prefetched_batches,
framerate)
print("Starting alphamerge")
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[fg];[2][fg]overlay=(main_w-overlay_w)/2:(main_h-overlay_h)/2:format=auto,fps=10,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse' -shortest %s" % (
file_path, temp_file, overlay, output)
sp.run(shlex.split(cmd))
print("Process finished")
return
def transparentvideo(output, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
matte_key(temp_file, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit,
prefetched_batches,
framerate)
print("Starting alphamerge")
cmd = "nice -10 ffmpeg -y -nostats -loglevel 0 -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1' -c:v qtrle -shortest %s" % (
file_path, temp_file, output)
process = sp.Popen(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
print('after call')
if stderr:
return "ERROR: %s" % stderr.decode("utf-8")
print("Process finished")
return
def transparentvideoovervideo(output, overlay, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
matte_key(temp_file, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit,
prefetched_batches,
framerate)
print("Starting alphamerge")
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[vid];[vid][2:v]scale2ref[fg][bg];[bg][fg]overlay=shortest=1[out]' -map [out] -shortest %s" % (
file_path, temp_file, overlay, output)
sp.run(shlex.split(cmd))
print("Process finished")
return
def transparentvideooverimage(output, overlay, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit=-1,
prefetched_batches=4,
framerate=-1):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
matte_key(temp_file, file_path,
worker_nodes,
gpu_batchsize,
model_name,
frame_limit,
prefetched_batches,
framerate)
print("Scale image")
temp_image = os.path.abspath("%s/new.jpg" % tmpdirname)
cmd = "nice -10 ffmpeg -i %s -i %s -filter_complex 'scale2ref[img][vid];[img]setsar=1;[vid]nullsink' -q:v 2 %s" % (
overlay, file_path, temp_image)
sp.run(shlex.split(cmd))
print("Starting alphamerge")
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[0][1]scale2ref[img][vid];[img]setsar=1[img];[vid]nullsink; [img][2]overlay=(W-w)/2:(H-h)/2' -shortest %s" % (
#cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[vid];[2:v][vid]overlay[out]' -map [out] -shortest %s" % (
temp_image, file_path, temp_file, output)
sp.run(shlex.split(cmd))
print("Process finished")
return
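These helpers can also be called directly from Python, which is essentially what the CLI does; a minimal sketch with placeholder paths:
```python
# Sketch: generate a matte key file and a transparent .mov without the CLI.
# cli.py sets the 'spawn' start method before calling these; the __main__ guard
# matters because matte_key launches worker processes.
from backgroundremover import utilities

if __name__ == "__main__":
    utilities.matte_key(
        "output.matte.mp4", "/path/to/video.mp4",
        worker_nodes=1, gpu_batchsize=2, model_name="u2net",
        frame_limit=-1, framerate=-1,
    )
    utilities.transparentvideo(
        "output.mov", "/path/to/video.mp4",
        worker_nodes=1, gpu_batchsize=2, model_name="u2net",
        frame_limit=-1, framerate=-1,
    )
```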