initial
This commit is contained in:
parent
7b6f56ec4e
commit
3473cc262e
11
.editorconfig
Normal file
11
.editorconfig
Normal file
@ -0,0 +1,11 @@
|
||||
# https://editorconfig.org/
|
||||
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
24
LICENSE.txt
Normal file
24
LICENSE.txt
Normal file
@ -0,0 +1,24 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Johnathan Nader
|
||||
Copyright (c) 2020 Lucas Nestler
|
||||
Copyright (c) 2020 Dr. Tim Scarfe
|
||||
Copyright (c) 2020 Daniel Gatis
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
12
MANIFEST.in
Normal file
12
MANIFEST.in
Normal file
@ -0,0 +1,12 @@
|
||||
include pyproject.toml
|
||||
|
||||
# Include the README
|
||||
include *.md
|
||||
|
||||
# Include the license file
|
||||
include LICENSE.txt
|
||||
|
||||
# Include the data files
|
||||
recursive-include data *
|
||||
|
||||
include requirements.txt
|
165
README.md
165
README.md
@ -1 +1,164 @@
|
||||
# backgroundremover
|
||||
# BackgroundRemover
|
||||
|
||||
A command line tool to remove background from [video](https://backgroundremover.app/video)
|
||||
and [image](https://backgroundremover.app/image), brought to you
|
||||
by [BackgroundRemover.app](https://backgroundremover.app) which is an app made by [nadermx](https://john.nader.mx) powered by this tool
|
||||
|
||||
<img alt="background remover video" src="https://backgroundremover.app/static/backgroundremover.gif" height="200" />
|
||||
<img alt="green screen matte key file" src="https://backgroundremover.app/static/matte.gif" height="200" width="110" />
|
||||
<img alt="background remover image" src="https://backgroundremover.app/static/backgroundremoverexample.png" height="200"/>
|
||||
|
||||
### Requirements
|
||||
|
||||
* python 3.6 (only one tested so far but may work for < 3.6)
|
||||
* python3.6-dev
|
||||
* torch and torchvision stable version (https://pytorch.org)
|
||||
|
||||
* ffmpeg 4.2+
|
||||
|
||||
#### How to install torch and fmpeg
|
||||
|
||||
Go to https://pytorch.org and scroll down to `INSTALL PYTORCH` section and follow the instructions.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
PyTorch Build: Stable (1.7.1)
|
||||
Your OS: Windows
|
||||
Package: Pip
|
||||
Language: Python
|
||||
CUDA: None
|
||||
```
|
||||
|
||||
To install ffmpeg
|
||||
|
||||
```
|
||||
sudo apt install ffmpeg python3.6-dev
|
||||
```
|
||||
|
||||
To install torch:
|
||||
|
||||
```
|
||||
pip install --upgrade pip
|
||||
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
To Install backgroundremover, install it from pypi
|
||||
|
||||
```bash
|
||||
pip install backgroundremover
|
||||
```
|
||||
|
||||
# Usage as a cli
|
||||
## Image
|
||||
|
||||
Remove the background from a local file image
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/image.jpeg" -o "output.png"
|
||||
```
|
||||
|
||||
### Advance usage for image background removal
|
||||
|
||||
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/image.jpeg" a -ae 15 -o "output.png"
|
||||
```
|
||||
change the model for diferent background removal methods between `u2netp`, `u2net`, or `u2net_human_seg`
|
||||
```bash
|
||||
backgroundremover -i "/path/to/image.jpeg" -m "u2net_human_seg" -o "output.png"
|
||||
```
|
||||
## Video
|
||||
|
||||
### remove background from video and make transparent mov
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -tv -o "output.mov"
|
||||
```
|
||||
###remove background from local video and overlay it over other video
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -tov -tv "/path/to/videtobeoverlayed.mp4" -o "output.mov"
|
||||
```
|
||||
|
||||
### remove background from video and make transparent gif
|
||||
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -tg -o "output.gif"
|
||||
```
|
||||
### Make matte key file (green screen overlay)
|
||||
|
||||
Make a matte file for premier
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -mk -o "output.matte.mp4"
|
||||
```
|
||||
|
||||
### Advance usage for video
|
||||
|
||||
Change the framerate of the video (default is set to 30)
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -fr 30 -tv -o "output.mov"
|
||||
```
|
||||
|
||||
Change the gpu batch size of the video (default is set to 1)
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -gp 4 -tv -o "output.mov"
|
||||
```
|
||||
|
||||
Change the number of workers working on video (default is set to 1)
|
||||
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -wn 4 -tv -o "output.mov"
|
||||
```
|
||||
change the model for diferent background removal methods between `u2netp`, `u2net`, or `u2net_human_seg`
|
||||
```bash
|
||||
backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg"-tv -o "output.mov"
|
||||
```
|
||||
|
||||
## Todo
|
||||
|
||||
- convert logic from video to image to utilize more GPU on image removal
|
||||
- remove duplicate imports from image and video of u2net models
|
||||
- clean up documentation a bit more
|
||||
- add ability to adjust and give feedback images or videos to datasets
|
||||
- other
|
||||
|
||||
### Pull requests
|
||||
|
||||
Accepted
|
||||
|
||||
### If you like this library
|
||||
|
||||
Give a link to our project [BackgroundRemover.app](https://backgroundremover.app) or this git, telling people that you like it or use it.
|
||||
|
||||
### Reason for project
|
||||
|
||||
We made it our own package after merging together parts of others, adding in a few features of our own via posting parts as bounty questions on superuser, etc. As well as asked on hackernews earlier to open source the image part, so decided to add in video, and a bit more.
|
||||
|
||||
|
||||
|
||||
### References
|
||||
|
||||
- https://arxiv.org/pdf/2005.09007.pdf
|
||||
- https://github.com/NathanUA/U-2-Net
|
||||
- https://github.com/pymatting/pymatting
|
||||
- https://github.com/danielgatis/rembg
|
||||
- https://github.com/ecsplendid/rembg-greenscreen
|
||||
- https://superuser.com/questions/1647590/have-ffmpeg-merge-a-matte-key-file-over-the-normal-video-file-removing-the-backg
|
||||
- https://superuser.com/questions/1648680/ffmpeg-alphamerge-two-videos-into-a-gif-with-transparent-background/1649339?noredirect=1#comment2522687_1649339
|
||||
- https://superuser.com/questions/1649817/ffmpeg-overlay-a-video-after-alphamerging-two-others/1649856#1649856
|
||||
|
||||
### License
|
||||
|
||||
- Copyright (c) 2021-present [Johnathan Nader](https://github.com/nadermx)
|
||||
- Copyright (c) 2020-present [Lucas Nestler](https://github.com/ClashLuke)
|
||||
- Copyright (c) 2020-present [Dr. Tim Scarfe](https://github.com/ecsplendid)
|
||||
- Copyright (c) 2020-present [Daniel Gatis](https://github.com/danielgatis)
|
||||
|
||||
Licensed under [MIT License](./LICENSE.txt)
|
||||
|
5
pyproject.toml
Normal file
5
pyproject.toml
Normal file
@ -0,0 +1,5 @@
|
||||
[build-system]
|
||||
# These are the assumed default build requirements from pip:
|
||||
# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
|
||||
requires = ["setuptools>=40.8.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
16
requirements.txt
Normal file
16
requirements.txt
Normal file
@ -0,0 +1,16 @@
|
||||
numpy>=1.19.4
|
||||
scikit-image>=0.17.2
|
||||
torch>=1.7.0
|
||||
torchvision>=0.8.1
|
||||
waitress>=1.4.4
|
||||
tqdm>=4.51.0
|
||||
requests>=2.24.0
|
||||
scipy>=1.5.4
|
||||
pymatting>=1.1.1
|
||||
filetype>=1.0.7
|
||||
hsh>=1.1.0
|
||||
more_itertools==8.7.0
|
||||
moviepy==1.0.3
|
||||
Pillow==8.1.1
|
||||
ffmpeg-python
|
||||
|
4
setup.cfg
Normal file
4
setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[metadata]
|
||||
# This includes the license file(s) in the wheel.
|
||||
# https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file
|
||||
license_files = LICENSE.txt
|
35
setup.py
Normal file
35
setup.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pathlib
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
here = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
long_description = (here / "README.md").read_text(encoding="utf-8")
|
||||
|
||||
with open("requirements.txt") as f:
|
||||
requireds = f.read().splitlines()
|
||||
|
||||
setup(
|
||||
name="backgroundremover",
|
||||
version="0.1.1",
|
||||
description="Background remover from image and video",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/nadermx/backgroundremover",
|
||||
author="Johnathan Nader",
|
||||
author_email="john@nader.mx",
|
||||
classifiers=[
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
],
|
||||
keywords="remove, background, u2net, remove background, background remover",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages(where="src"),
|
||||
python_requires=">=3.6, <4",
|
||||
install_requires=requireds,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"backgroundremover=backgroundremover.cmd.cli:main",
|
||||
],
|
||||
},
|
||||
)
|
9
src/backgroundremover/__init__.py
Normal file
9
src/backgroundremover/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""
|
||||
backgroundremover
|
||||
|
||||
A library to remove background from videos and images
|
||||
"""
|
||||
|
||||
__version__ = "0.1.1"
|
||||
__author__ = 'Johnathan Nader'
|
||||
__credits__ = 'BackgroundRemover.app'
|
201
src/backgroundremover/bg.py
Normal file
201
src/backgroundremover/bg.py
Normal file
@ -0,0 +1,201 @@
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import typing
|
||||
from PIL import Image
|
||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
||||
from pymatting.util.util import stack_images
|
||||
from scipy.ndimage.morphology import binary_erosion
|
||||
import moviepy.editor as mpy
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional
|
||||
import torch.nn.functional
|
||||
from hsh.library.hash import Hasher
|
||||
from tqdm import tqdm
|
||||
from .u2net import detect, u2net
|
||||
|
||||
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
def __init__(self, model_name):
|
||||
super(Net, self).__init__()
|
||||
hasher = Hasher()
|
||||
|
||||
model, hash_val, drive_target, env_var = {
|
||||
'u2netp': (u2net.U2NETP,
|
||||
'e4f636406ca4e2af789941e7f139ee2e',
|
||||
'1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy',
|
||||
'U2NET_PATH'),
|
||||
'u2net': (u2net.U2NET,
|
||||
'09fb4e49b7f785c9f855baf94916840a',
|
||||
'1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P',
|
||||
'U2NET_PATH'),
|
||||
'u2net_human_seg': (u2net.U2NET,
|
||||
'347c3d51b01528e5c6c071e3cff1cb55',
|
||||
'1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ',
|
||||
'U2NET_PATH')
|
||||
}[model_name]
|
||||
path = os.environ.get(env_var, os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")))
|
||||
net = model(3, 1)
|
||||
if not os.path.exists(path) or hasher.md5(path) != hash_val:
|
||||
head, tail = os.path.split(path)
|
||||
os.makedirs(head, exist_ok=True)
|
||||
|
||||
URL = "https://docs.google.com/uc?export=download"
|
||||
|
||||
session = requests.Session()
|
||||
response = session.get(URL, params={"id": drive_target}, stream=True)
|
||||
|
||||
token = None
|
||||
for key, value in response.cookies.items():
|
||||
if key.startswith("download_warning"):
|
||||
token = value
|
||||
break
|
||||
|
||||
if token:
|
||||
params = {"id": drive_target, "confirm": token}
|
||||
response = session.get(URL, params=params, stream=True)
|
||||
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
|
||||
with open(path, "wb") as file, tqdm(
|
||||
desc=f"Downloading {tail} to {head}",
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
net.load_state_dict(torch.load(path, map_location=torch.device(DEVICE)))
|
||||
net.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
|
||||
net.eval()
|
||||
self.net = net
|
||||
|
||||
def forward(self, block_input: torch.Tensor):
|
||||
image_data = block_input.permute(0, 3, 1, 2)
|
||||
original_shape = image_data.shape[2:]
|
||||
image_data = torch.nn.functional.interpolate(image_data, (320, 320), mode='bilinear')
|
||||
image_data = (image_data / 255 - 0.485) / 0.229
|
||||
out = self.net(image_data)[0][:, 0:1]
|
||||
ma = torch.max(out)
|
||||
mi = torch.min(out)
|
||||
out = (out - mi) / (ma - mi) * 255
|
||||
out = torch.nn.functional.interpolate(out, original_shape, mode='bilinear')
|
||||
out = out[:, 0]
|
||||
out = out.to(dtype=torch.uint8, device=torch.device('cpu'), non_blocking=True).detach()
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def alpha_matting_cutout(
|
||||
img,
|
||||
mask,
|
||||
foreground_threshold,
|
||||
background_threshold,
|
||||
erode_structure_size,
|
||||
base_size,
|
||||
):
|
||||
size = img.size
|
||||
|
||||
img.thumbnail((base_size, base_size), Image.LANCZOS)
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
img = np.asarray(img)
|
||||
mask = np.asarray(mask)
|
||||
|
||||
# guess likely foreground/background
|
||||
is_foreground = mask > foreground_threshold
|
||||
is_background = mask < background_threshold
|
||||
|
||||
# erode foreground/background
|
||||
structure = None
|
||||
if erode_structure_size > 0:
|
||||
structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
|
||||
|
||||
is_foreground = binary_erosion(is_foreground, structure=structure)
|
||||
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
||||
|
||||
# build trimap
|
||||
# 0 = background
|
||||
# 128 = unknown
|
||||
# 255 = foreground
|
||||
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
||||
trimap[is_foreground] = 255
|
||||
trimap[is_background] = 0
|
||||
|
||||
# build the cutout image
|
||||
img_normalized = img / 255.0
|
||||
trimap_normalized = trimap / 255.0
|
||||
|
||||
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
||||
foreground = estimate_foreground_ml(img_normalized, alpha)
|
||||
cutout = stack_images(foreground, alpha)
|
||||
|
||||
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
||||
cutout = Image.fromarray(cutout)
|
||||
cutout = cutout.resize(size, Image.LANCZOS)
|
||||
|
||||
return cutout
|
||||
|
||||
|
||||
def naive_cutout(img, mask):
|
||||
empty = Image.new("RGBA", (img.size), 0)
|
||||
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
||||
return cutout
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_model(model_name):
|
||||
if model_name == "u2netp":
|
||||
return detect.load_model(model_name="u2netp")
|
||||
if model_name == "u2net_human_seg":
|
||||
return detect.load_model(model_name="u2net_human_seg")
|
||||
else:
|
||||
return detect.load_model(model_name="u2net")
|
||||
|
||||
|
||||
def remove(
|
||||
data,
|
||||
model_name="u2net",
|
||||
alpha_matting=False,
|
||||
alpha_matting_foreground_threshold=240,
|
||||
alpha_matting_background_threshold=10,
|
||||
alpha_matting_erode_structure_size=10,
|
||||
alpha_matting_base_size=1000,
|
||||
):
|
||||
model = get_model(model_name)
|
||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||
mask = detect.predict(model, np.array(img)).convert("L")
|
||||
|
||||
if alpha_matting:
|
||||
cutout = alpha_matting_cutout(
|
||||
img,
|
||||
mask,
|
||||
alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size,
|
||||
alpha_matting_base_size,
|
||||
)
|
||||
else:
|
||||
cutout = naive_cutout(img, mask)
|
||||
|
||||
bio = io.BytesIO()
|
||||
cutout.save(bio, "PNG")
|
||||
|
||||
return bio.getbuffer()
|
||||
|
||||
|
||||
def iter_frames(path):
|
||||
return mpy.VideoFileClip(path).resize(height=320).iter_frames(dtype="uint8")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def remove_many(image_data: typing.List[np.array], net: Net):
|
||||
image_data = np.stack(image_data)
|
||||
image_data = torch.as_tensor(image_data, dtype=torch.float32, device=DEVICE)
|
||||
return net(image_data).numpy()
|
0
src/backgroundremover/cmd/__init__.py
Normal file
0
src/backgroundremover/cmd/__init__.py
Normal file
264
src/backgroundremover/cmd/cli.py
Normal file
264
src/backgroundremover/cmd/cli.py
Normal file
@ -0,0 +1,264 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
|
||||
from ..bg import remove
|
||||
from ..utilities import matte_key, transparentgif, transparentvideo, transparentvideoovervideo, transparentvideooverimage, \
|
||||
transparentgifwithbackground
|
||||
import torch
|
||||
from .. import utilities
|
||||
|
||||
|
||||
def main():
|
||||
model_path = os.environ.get(
|
||||
"U2NETP_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net")),
|
||||
)
|
||||
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
||||
if len(model_choices) == 0:
|
||||
model_choices = ["u2net", "u2netp", "u2net_human_seg"]
|
||||
|
||||
ap = argparse.ArgumentParser()
|
||||
|
||||
ap.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=str,
|
||||
choices=model_choices,
|
||||
help="The model name, u2net, u2netp, u2net_human_seg",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--alpha-matting",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="When true use alpha matting cutout.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-af",
|
||||
"--alpha-matting-foreground-threshold",
|
||||
default=240,
|
||||
type=int,
|
||||
help="The trimap foreground threshold.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-ab",
|
||||
"--alpha-matting-background-threshold",
|
||||
default=10,
|
||||
type=int,
|
||||
help="The trimap background threshold.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-ae",
|
||||
"--alpha-matting-erode-size",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Size of element used for the erosion.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
help="The image base size.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-wn",
|
||||
"--workernodes",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of parallel workers"
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-gb",
|
||||
"--gpubatchsize",
|
||||
default=2,
|
||||
type=int,
|
||||
help="GPU batchsize"
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-fr",
|
||||
"--framerate",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="Override the frame rate"
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-fl",
|
||||
"--framelimit",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="Limit the number of frames to process for quick testing.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-mk",
|
||||
"--mattekey",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Output the Matte key file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-tv",
|
||||
"--transparentvideo",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Output transparent video format mov",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-tov",
|
||||
"--transparentvideoovervideo",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Overlay transparent video over another video",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-toi",
|
||||
"--transparentvideooverimage",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Overlay transparent video over another video",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-tg",
|
||||
"--transparentgif",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Make transparent gif from video",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-tgwb",
|
||||
"--transparentgifwithbackground",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="Make transparent background overlay a background image",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
nargs="?",
|
||||
default="-",
|
||||
type=argparse.FileType("rb"),
|
||||
help="Path to the input video or image.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-bi",
|
||||
"--backgroundimage",
|
||||
nargs="?",
|
||||
default="-",
|
||||
type=argparse.FileType("rb"),
|
||||
help="Path to background image.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-bv",
|
||||
"--backgroundvideo",
|
||||
nargs="?",
|
||||
default="-",
|
||||
type=argparse.FileType("rb"),
|
||||
help="Path to background video.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
nargs="?",
|
||||
default="-",
|
||||
type=argparse.FileType("wb"),
|
||||
help="Path to the output",
|
||||
)
|
||||
|
||||
args = ap.parse_args()
|
||||
if args.input.name.rsplit('.', 1)[1] in ['mp4', 'mov', 'webm', 'ogg', 'gif']:
|
||||
if args.mattekey:
|
||||
matte_key(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
elif args.transparentvideo:
|
||||
transparentvideo(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
elif args.transparentvideoovervideo:
|
||||
transparentvideoovervideo(os.path.abspath(args.output.name), os.path.abspath(args.backgroundvideo.name),
|
||||
os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
elif args.transparentvideooverimage:
|
||||
transparentvideooverimage(os.path.abspath(args.output.name), os.path.abspath(args.backgroundimage.name),
|
||||
os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
elif args.transparentgif:
|
||||
transparentgif(os.path.abspath(args.output.name), os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
elif args.transparentgifwithbackground:
|
||||
transparentgifwithbackground(os.path.abspath(args.output.name), os.path.abspath(args.backgroundimage.name), os.path.abspath(args.input.name),
|
||||
worker_nodes=args.workernodes,
|
||||
gpu_batchsize=args.gpubatchsize,
|
||||
model_name=args.model,
|
||||
frame_limit=args.framelimit,
|
||||
framerate=args.framerate)
|
||||
|
||||
else:
|
||||
print(args.output.name)
|
||||
r = lambda i: i.buffer.read() if hasattr(i, "buffer") else i.read()
|
||||
w = lambda o, data: o.buffer.write(data) if hasattr(o, "buffer") else o.write(data)
|
||||
w(
|
||||
args.output,
|
||||
remove(
|
||||
r(args.input),
|
||||
model_name=args.model,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||
alpha_matting_base_size=args.alpha_matting_base_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.multiprocessing.set_start_method('spawn')
|
||||
main()
|
98
src/backgroundremover/cmd/server.py
Normal file
98
src/backgroundremover/cmd/server.py
Normal file
@ -0,0 +1,98 @@
|
||||
import os
|
||||
import glob
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote_plus
|
||||
from urllib.request import urlopen
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
from waitress import serve
|
||||
|
||||
from ..bg import remove
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/", methods=["GET", "POST"])
|
||||
def index():
|
||||
file_content = ""
|
||||
|
||||
if request.method == "POST":
|
||||
if "file" not in request.files:
|
||||
return {"error": "missing post form param 'file'"}, 400
|
||||
|
||||
file_content = request.files["file"].read()
|
||||
|
||||
if request.method == "GET":
|
||||
url = request.args.get("url", type=str)
|
||||
if url is None:
|
||||
return {"error": "missing query param 'url'"}, 400
|
||||
|
||||
file_content = urlopen(unquote_plus(url)).read()
|
||||
|
||||
if file_content == "":
|
||||
return {"error": "File content is empty"}, 400
|
||||
|
||||
alpha_matting = "a" in request.values
|
||||
af = request.values.get("af", type=int, default=240)
|
||||
ab = request.values.get("ab", type=int, default=10)
|
||||
ae = request.values.get("ae", type=int, default=10)
|
||||
az = request.values.get("az", type=int, default=1000)
|
||||
|
||||
model = request.args.get("model", type=str, default="u2net")
|
||||
model_path = os.environ.get(
|
||||
"U2NETP_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net")),
|
||||
)
|
||||
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
||||
if len(model_choices) == 0:
|
||||
model_choices = ["u2net", "u2netp", "u2net_human_seg"]
|
||||
|
||||
if model not in model_choices:
|
||||
return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
|
||||
|
||||
try:
|
||||
return send_file(
|
||||
BytesIO(
|
||||
remove(
|
||||
file_content,
|
||||
model_name=model,
|
||||
alpha_matting=alpha_matting,
|
||||
alpha_matting_foreground_threshold=af,
|
||||
alpha_matting_background_threshold=ab,
|
||||
alpha_matting_erode_structure_size=ae,
|
||||
alpha_matting_base_size=az,
|
||||
)
|
||||
),
|
||||
mimetype="image/png",
|
||||
)
|
||||
except Exception as e:
|
||||
app.logger.exception(e, exc_info=True)
|
||||
return {"error": "oops, something went wrong!"}, 500
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--addr",
|
||||
default="0.0.0.0",
|
||||
type=str,
|
||||
help="The IP address to bind to.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
default=5000,
|
||||
type=int,
|
||||
help="The port to bind to.",
|
||||
)
|
||||
|
||||
args = ap.parse_args()
|
||||
serve(app, host=args.addr, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
src/backgroundremover/u2net/__init__.py
Normal file
0
src/backgroundremover/u2net/__init__.py
Normal file
324
src/backgroundremover/u2net/data_loader.py
Normal file
324
src/backgroundremover/u2net/data_loader.py
Normal file
@ -0,0 +1,324 @@
|
||||
# data loader
|
||||
from __future__ import division, print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from skimage import color, io, transform
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
# ==========================dataset load==========================
|
||||
class RescaleT(object):
|
||||
def __init__(self, output_size):
|
||||
assert isinstance(output_size, (int, tuple))
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, sample):
|
||||
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
|
||||
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if isinstance(self.output_size, int):
|
||||
if h > w:
|
||||
new_h, new_w = self.output_size * h / w, self.output_size
|
||||
else:
|
||||
new_h, new_w = self.output_size, self.output_size * w / h
|
||||
else:
|
||||
new_h, new_w = self.output_size
|
||||
|
||||
new_h, new_w = int(new_h), int(new_w)
|
||||
|
||||
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
|
||||
# img = transform.resize(image,(new_h,new_w),mode='constant')
|
||||
# lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
|
||||
|
||||
img = transform.resize(
|
||||
image, (self.output_size, self.output_size), mode="constant"
|
||||
)
|
||||
lbl = transform.resize(
|
||||
label,
|
||||
(self.output_size, self.output_size),
|
||||
mode="constant",
|
||||
order=0,
|
||||
preserve_range=True,
|
||||
)
|
||||
|
||||
return {"imidx": imidx, "image": img, "label": lbl}
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
def __init__(self, output_size):
|
||||
assert isinstance(output_size, (int, tuple))
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, sample):
|
||||
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
|
||||
|
||||
if random.random() >= 0.5:
|
||||
image = image[::-1]
|
||||
label = label[::-1]
|
||||
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if isinstance(self.output_size, int):
|
||||
if h > w:
|
||||
new_h, new_w = self.output_size * h / w, self.output_size
|
||||
else:
|
||||
new_h, new_w = self.output_size, self.output_size * w / h
|
||||
else:
|
||||
new_h, new_w = self.output_size
|
||||
|
||||
new_h, new_w = int(new_h), int(new_w)
|
||||
|
||||
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
|
||||
img = transform.resize(image, (new_h, new_w), mode="constant")
|
||||
lbl = transform.resize(
|
||||
label, (new_h, new_w), mode="constant", order=0, preserve_range=True
|
||||
)
|
||||
|
||||
return {"imidx": imidx, "image": img, "label": lbl}
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
def __init__(self, output_size):
|
||||
assert isinstance(output_size, (int, tuple))
|
||||
if isinstance(output_size, int):
|
||||
self.output_size = (output_size, output_size)
|
||||
else:
|
||||
assert len(output_size) == 2
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, sample):
|
||||
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
|
||||
|
||||
if random.random() >= 0.5:
|
||||
image = image[::-1]
|
||||
label = label[::-1]
|
||||
|
||||
h, w = image.shape[:2]
|
||||
new_h, new_w = self.output_size
|
||||
|
||||
top = np.random.randint(0, h - new_h)
|
||||
left = np.random.randint(0, w - new_w)
|
||||
|
||||
image = image[top : top + new_h, left : left + new_w]
|
||||
label = label[top : top + new_h, left : left + new_w]
|
||||
|
||||
return {"imidx": imidx, "image": image, "label": label}
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
"""Convert ndarrays in sample to Tensors."""
|
||||
|
||||
def __call__(self, sample):
|
||||
|
||||
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
|
||||
|
||||
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
|
||||
tmpLbl = np.zeros(label.shape)
|
||||
|
||||
image = image / np.max(image)
|
||||
if np.max(label) < 1e-6:
|
||||
label = label
|
||||
else:
|
||||
label = label / np.max(label)
|
||||
|
||||
if image.shape[2] == 1:
|
||||
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
|
||||
else:
|
||||
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
|
||||
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
|
||||
|
||||
tmpLbl[:, :, 0] = label[:, :, 0]
|
||||
|
||||
# change the r,g,b to b,r,g from [0,255] to [0,1]
|
||||
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
|
||||
tmpImg = tmpImg.transpose((2, 0, 1))
|
||||
tmpLbl = label.transpose((2, 0, 1))
|
||||
|
||||
return {
|
||||
"imidx": torch.from_numpy(imidx),
|
||||
"image": torch.from_numpy(tmpImg),
|
||||
"label": torch.from_numpy(tmpLbl),
|
||||
}
|
||||
|
||||
|
||||
class ToTensorLab(object):
|
||||
"""Convert ndarrays in sample to Tensors."""
|
||||
|
||||
def __init__(self, flag=0):
|
||||
self.flag = flag
|
||||
|
||||
def __call__(self, sample):
|
||||
|
||||
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
|
||||
|
||||
tmpLbl = np.zeros(label.shape)
|
||||
|
||||
if np.max(label) < 1e-6:
|
||||
label = label
|
||||
else:
|
||||
label = label / np.max(label)
|
||||
|
||||
# change the color space
|
||||
if self.flag == 2: # with rgb and Lab colors
|
||||
tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
|
||||
tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
|
||||
if image.shape[2] == 1:
|
||||
tmpImgt[:, :, 0] = image[:, :, 0]
|
||||
tmpImgt[:, :, 1] = image[:, :, 0]
|
||||
tmpImgt[:, :, 2] = image[:, :, 0]
|
||||
else:
|
||||
tmpImgt = image
|
||||
tmpImgtl = color.rgb2lab(tmpImgt)
|
||||
|
||||
# nomalize image to range [0,1]
|
||||
tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
|
||||
np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
|
||||
)
|
||||
tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
|
||||
np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
|
||||
)
|
||||
tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
|
||||
np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
|
||||
)
|
||||
tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
|
||||
np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
|
||||
)
|
||||
tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
|
||||
np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
|
||||
)
|
||||
tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
|
||||
np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
|
||||
)
|
||||
|
||||
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
|
||||
|
||||
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
|
||||
tmpImg[:, :, 0]
|
||||
)
|
||||
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
|
||||
tmpImg[:, :, 1]
|
||||
)
|
||||
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
|
||||
tmpImg[:, :, 2]
|
||||
)
|
||||
tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
|
||||
tmpImg[:, :, 3]
|
||||
)
|
||||
tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
|
||||
tmpImg[:, :, 4]
|
||||
)
|
||||
tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
|
||||
tmpImg[:, :, 5]
|
||||
)
|
||||
|
||||
elif self.flag == 1: # with Lab color
|
||||
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
|
||||
|
||||
if image.shape[2] == 1:
|
||||
tmpImg[:, :, 0] = image[:, :, 0]
|
||||
tmpImg[:, :, 1] = image[:, :, 0]
|
||||
tmpImg[:, :, 2] = image[:, :, 0]
|
||||
else:
|
||||
tmpImg = image
|
||||
|
||||
tmpImg = color.rgb2lab(tmpImg)
|
||||
|
||||
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
|
||||
|
||||
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
|
||||
np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
|
||||
)
|
||||
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
|
||||
np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
|
||||
)
|
||||
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
|
||||
np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
|
||||
)
|
||||
|
||||
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
|
||||
tmpImg[:, :, 0]
|
||||
)
|
||||
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
|
||||
tmpImg[:, :, 1]
|
||||
)
|
||||
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
|
||||
tmpImg[:, :, 2]
|
||||
)
|
||||
|
||||
else: # with rgb color
|
||||
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
|
||||
image = image / np.max(image)
|
||||
if image.shape[2] == 1:
|
||||
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
|
||||
else:
|
||||
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
||||
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
|
||||
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
|
||||
|
||||
tmpLbl[:, :, 0] = label[:, :, 0]
|
||||
|
||||
# change the r,g,b to b,r,g from [0,255] to [0,1]
|
||||
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
|
||||
tmpImg = tmpImg.transpose((2, 0, 1))
|
||||
tmpLbl = label.transpose((2, 0, 1))
|
||||
|
||||
return {
|
||||
"imidx": torch.from_numpy(imidx),
|
||||
"image": torch.from_numpy(tmpImg),
|
||||
"label": torch.from_numpy(tmpLbl),
|
||||
}
|
||||
|
||||
|
||||
class SalObjDataset(Dataset):
|
||||
def __init__(self, img_name_list, lbl_name_list, transform=None):
|
||||
# self.root_dir = root_dir
|
||||
# self.image_name_list = glob.glob(image_dir+'*.png')
|
||||
# self.label_name_list = glob.glob(label_dir+'*.png')
|
||||
self.image_name_list = img_name_list
|
||||
self.label_name_list = lbl_name_list
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_name_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
# image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
|
||||
# label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
|
||||
|
||||
image = io.imread(self.image_name_list[idx])
|
||||
imname = self.image_name_list[idx]
|
||||
imidx = np.array([idx])
|
||||
|
||||
if 0 == len(self.label_name_list):
|
||||
label_3 = np.zeros(image.shape)
|
||||
else:
|
||||
label_3 = io.imread(self.label_name_list[idx])
|
||||
|
||||
label = np.zeros(label_3.shape[0:2])
|
||||
if 3 == len(label_3.shape):
|
||||
label = label_3[:, :, 0]
|
||||
elif 2 == len(label_3.shape):
|
||||
label = label_3
|
||||
|
||||
if 3 == len(image.shape) and 2 == len(label.shape):
|
||||
label = label[:, :, np.newaxis]
|
||||
elif 2 == len(image.shape) and 2 == len(label.shape):
|
||||
image = image[:, :, np.newaxis]
|
||||
label = label[:, :, np.newaxis]
|
||||
|
||||
sample = {"imidx": imidx, "image": image, "label": label}
|
||||
|
||||
if self.transform:
|
||||
sample = self.transform(sample)
|
||||
|
||||
return sample
|
178
src/backgroundremover/u2net/detect.py
Normal file
178
src/backgroundremover/u2net/detect.py
Normal file
@ -0,0 +1,178 @@
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from hsh.library.hash import Hasher
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import data_loader, u2net
|
||||
|
||||
|
||||
def download_file_from_google_drive(id, fname, destination):
|
||||
head, tail = os.path.split(destination)
|
||||
os.makedirs(head, exist_ok=True)
|
||||
|
||||
URL = "https://docs.google.com/uc?export=download"
|
||||
|
||||
session = requests.Session()
|
||||
response = session.get(URL, params={"id": id}, stream=True)
|
||||
|
||||
token = None
|
||||
for key, value in response.cookies.items():
|
||||
if key.startswith("download_warning"):
|
||||
token = value
|
||||
break
|
||||
|
||||
if token:
|
||||
params = {"id": id, "confirm": token}
|
||||
response = session.get(URL, params=params, stream=True)
|
||||
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
|
||||
with open(destination, "wb") as file, tqdm(
|
||||
desc=f"Downloading {tail} to {head}",
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
|
||||
|
||||
def load_model(model_name: str = "u2net"):
|
||||
hasher = Hasher()
|
||||
|
||||
if model_name == "u2netp":
|
||||
net = u2net.U2NETP(3, 1)
|
||||
path = os.environ.get(
|
||||
"U2NETP_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
|
||||
"u2netp.pth",
|
||||
path,
|
||||
)
|
||||
|
||||
elif model_name == "u2net":
|
||||
net = u2net.U2NET(3, 1)
|
||||
path = os.environ.get(
|
||||
"U2NET_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
||||
"u2net.pth",
|
||||
path,
|
||||
)
|
||||
|
||||
elif model_name == "u2net_human_seg":
|
||||
net = u2net.U2NET(3, 1)
|
||||
path = os.environ.get(
|
||||
"U2NET_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
|
||||
"u2net_human_seg.pth",
|
||||
path,
|
||||
)
|
||||
else:
|
||||
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(path))
|
||||
net.to(torch.device("cuda"))
|
||||
else:
|
||||
net.load_state_dict(
|
||||
torch.load(
|
||||
path,
|
||||
map_location="cpu",
|
||||
)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(
|
||||
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
|
||||
)
|
||||
|
||||
net.eval()
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def norm_pred(d):
|
||||
ma = torch.max(d)
|
||||
mi = torch.min(d)
|
||||
dn = (d - mi) / (ma - mi)
|
||||
|
||||
return dn
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
label_3 = np.zeros(image.shape)
|
||||
label = np.zeros(label_3.shape[0:2])
|
||||
|
||||
if 3 == len(label_3.shape):
|
||||
label = label_3[:, :, 0]
|
||||
elif 2 == len(label_3.shape):
|
||||
label = label_3
|
||||
|
||||
if 3 == len(image.shape) and 2 == len(label.shape):
|
||||
label = label[:, :, np.newaxis]
|
||||
elif 2 == len(image.shape) and 2 == len(label.shape):
|
||||
image = image[:, :, np.newaxis]
|
||||
label = label[:, :, np.newaxis]
|
||||
|
||||
transform = transforms.Compose(
|
||||
[data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
|
||||
)
|
||||
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
def predict(net, item):
|
||||
|
||||
sample = preprocess(item)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
inputs_test = torch.cuda.FloatTensor(
|
||||
sample["image"].unsqueeze(0).cuda().float()
|
||||
)
|
||||
else:
|
||||
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
|
||||
|
||||
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
|
||||
|
||||
pred = d1[:, 0, :, :]
|
||||
predict = norm_pred(pred)
|
||||
|
||||
predict = predict.squeeze()
|
||||
predict_np = predict.cpu().detach().numpy()
|
||||
img = Image.fromarray(predict_np * 255).convert("RGB")
|
||||
|
||||
del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
|
||||
|
||||
return img
|
541
src/backgroundremover/u2net/u2net.py
Normal file
541
src/backgroundremover/u2net/u2net.py
Normal file
@ -0,0 +1,541 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
|
||||
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
||||
super(REBNCONV, self).__init__()
|
||||
|
||||
self.conv_s1 = nn.Conv2d(
|
||||
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
|
||||
)
|
||||
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
||||
self.relu_s1 = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
||||
|
||||
return xout
|
||||
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src, tar):
|
||||
|
||||
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU7, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
|
||||
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
|
||||
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
|
||||
hx5 = self.rebnconv5(hx)
|
||||
hx = self.pool5(hx5)
|
||||
|
||||
hx6 = self.rebnconv6(hx)
|
||||
|
||||
hx7 = self.rebnconv7(hx6)
|
||||
|
||||
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
||||
hx6dup = _upsample_like(hx6d, hx5)
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-6 ###
|
||||
class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU6, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
|
||||
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
|
||||
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
|
||||
hx5 = self.rebnconv5(hx)
|
||||
|
||||
hx6 = self.rebnconv6(hx5)
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-5 ###
|
||||
class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU5, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
|
||||
hx5 = self.rebnconv5(hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-4 ###
|
||||
class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-4F ###
|
||||
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4F, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
||||
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx2 = self.rebnconv2(hx1)
|
||||
hx3 = self.rebnconv3(hx2)
|
||||
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
##### U^2-Net ####
|
||||
class U2NET(nn.Module):
|
||||
def __init__(self, in_ch=3, out_ch=1):
|
||||
super(U2NET, self).__init__()
|
||||
|
||||
self.stage1 = RSU7(in_ch, 32, 64)
|
||||
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage2 = RSU6(64, 32, 128)
|
||||
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage3 = RSU5(128, 64, 256)
|
||||
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage4 = RSU4(256, 128, 512)
|
||||
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage5 = RSU4F(512, 256, 512)
|
||||
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage6 = RSU4F(512, 256, 512)
|
||||
|
||||
# decoder
|
||||
self.stage5d = RSU4F(1024, 256, 512)
|
||||
self.stage4d = RSU4(1024, 128, 256)
|
||||
self.stage3d = RSU5(512, 64, 128)
|
||||
self.stage2d = RSU6(256, 32, 64)
|
||||
self.stage1d = RSU7(128, 16, 64)
|
||||
|
||||
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
||||
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
||||
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||
|
||||
self.outconv = nn.Conv2d(6, out_ch, 1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
# stage 1
|
||||
hx1 = self.stage1(hx)
|
||||
hx = self.pool12(hx1)
|
||||
|
||||
# stage 2
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
|
||||
# stage 3
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
|
||||
# stage 4
|
||||
hx4 = self.stage4(hx)
|
||||
hx = self.pool45(hx4)
|
||||
|
||||
# stage 5
|
||||
hx5 = self.stage5(hx)
|
||||
hx = self.pool56(hx5)
|
||||
|
||||
# stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = _upsample_like(hx6, hx5)
|
||||
|
||||
# -------------------- decoder --------------------
|
||||
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
|
||||
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
# side output
|
||||
d1 = self.side1(hx1d)
|
||||
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = _upsample_like(d2, d1)
|
||||
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = _upsample_like(d3, d1)
|
||||
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = _upsample_like(d4, d1)
|
||||
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = _upsample_like(d5, d1)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = _upsample_like(d6, d1)
|
||||
|
||||
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
||||
|
||||
return (
|
||||
torch.sigmoid(d0),
|
||||
torch.sigmoid(d1),
|
||||
torch.sigmoid(d2),
|
||||
torch.sigmoid(d3),
|
||||
torch.sigmoid(d4),
|
||||
torch.sigmoid(d5),
|
||||
torch.sigmoid(d6),
|
||||
)
|
||||
|
||||
|
||||
### U^2-Net small ###
|
||||
class U2NETP(nn.Module):
|
||||
def __init__(self, in_ch=3, out_ch=1):
|
||||
super(U2NETP, self).__init__()
|
||||
|
||||
self.stage1 = RSU7(in_ch, 16, 64)
|
||||
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage2 = RSU6(64, 16, 64)
|
||||
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage3 = RSU5(64, 16, 64)
|
||||
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage4 = RSU4(64, 16, 64)
|
||||
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage5 = RSU4F(64, 16, 64)
|
||||
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
|
||||
self.stage6 = RSU4F(64, 16, 64)
|
||||
|
||||
# decoder
|
||||
self.stage5d = RSU4F(128, 16, 64)
|
||||
self.stage4d = RSU4(128, 16, 64)
|
||||
self.stage3d = RSU5(128, 16, 64)
|
||||
self.stage2d = RSU6(128, 16, 64)
|
||||
self.stage1d = RSU7(128, 16, 64)
|
||||
|
||||
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
|
||||
self.outconv = nn.Conv2d(6, out_ch, 1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
|
||||
# stage 1
|
||||
hx1 = self.stage1(hx)
|
||||
hx = self.pool12(hx1)
|
||||
|
||||
# stage 2
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
|
||||
# stage 3
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
|
||||
# stage 4
|
||||
hx4 = self.stage4(hx)
|
||||
hx = self.pool45(hx4)
|
||||
|
||||
# stage 5
|
||||
hx5 = self.stage5(hx)
|
||||
hx = self.pool56(hx5)
|
||||
|
||||
# stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = _upsample_like(hx6, hx5)
|
||||
|
||||
# decoder
|
||||
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
|
||||
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
|
||||
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
|
||||
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
|
||||
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
||||
|
||||
# side output
|
||||
d1 = self.side1(hx1d)
|
||||
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = _upsample_like(d2, d1)
|
||||
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = _upsample_like(d3, d1)
|
||||
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = _upsample_like(d4, d1)
|
||||
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = _upsample_like(d5, d1)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = _upsample_like(d6, d1)
|
||||
|
||||
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
||||
|
||||
return (
|
||||
torch.sigmoid(d0),
|
||||
torch.sigmoid(d1),
|
||||
torch.sigmoid(d2),
|
||||
torch.sigmoid(d3),
|
||||
torch.sigmoid(d4),
|
||||
torch.sigmoid(d5),
|
||||
torch.sigmoid(d6),
|
||||
)
|
285
src/backgroundremover/utilities.py
Normal file
285
src/backgroundremover/utilities.py
Normal file
@ -0,0 +1,285 @@
|
||||
import os
|
||||
import math
|
||||
import multiprocessing
|
||||
import subprocess as sp
|
||||
import time
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import tempfile
|
||||
from .bg import DEVICE, Net, iter_frames, remove_many
|
||||
import shlex
|
||||
|
||||
def worker(worker_nodes,
|
||||
worker_index,
|
||||
result_dict,
|
||||
model_name,
|
||||
gpu_batchsize,
|
||||
total_frames,
|
||||
frames_dict):
|
||||
print(F"WORKER {worker_index} ONLINE")
|
||||
|
||||
output_index = worker_index + 1
|
||||
base_index = worker_index * gpu_batchsize
|
||||
net = Net(model_name)
|
||||
script_net = None
|
||||
for fi in (list(range(base_index + i * worker_nodes * gpu_batchsize,
|
||||
min(base_index + i * worker_nodes * gpu_batchsize + gpu_batchsize, total_frames)))
|
||||
for i in range(math.ceil(total_frames / worker_nodes / gpu_batchsize))):
|
||||
if not fi:
|
||||
break
|
||||
|
||||
# are we processing frames faster than the frame ripper is saving them?
|
||||
last = fi[-1]
|
||||
while last not in frames_dict:
|
||||
time.sleep(0.1)
|
||||
|
||||
input_frames = [frames_dict[index] for index in fi]
|
||||
if script_net is None:
|
||||
script_net = torch.jit.trace(net,
|
||||
torch.as_tensor(np.stack(input_frames), dtype=torch.float32, device=DEVICE))
|
||||
|
||||
result_dict[output_index] = remove_many(input_frames, script_net)
|
||||
|
||||
# clean up the frame buffer
|
||||
for fdex in fi:
|
||||
del frames_dict[fdex]
|
||||
output_index += worker_nodes
|
||||
|
||||
|
||||
def capture_frames(file_path, frames_dict, prefetched_samples, total_frames):
|
||||
print(F"WORKER FRAMERIPPER ONLINE")
|
||||
for idx, frame in enumerate(iter_frames(file_path)):
|
||||
frames_dict[idx] = frame
|
||||
while len(frames_dict) > prefetched_samples:
|
||||
time.sleep(0.1)
|
||||
if idx > total_frames:
|
||||
break
|
||||
|
||||
|
||||
def matte_key(output, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
manager = multiprocessing.Manager()
|
||||
|
||||
results_dict = manager.dict()
|
||||
frames_dict = manager.dict()
|
||||
|
||||
print(file_path)
|
||||
|
||||
info = ffmpeg.probe(file_path)
|
||||
total_frames = int(info["streams"][0]["nb_frames"])
|
||||
|
||||
if frame_limit != -1:
|
||||
total_frames = min(frame_limit, total_frames)
|
||||
|
||||
fr = info["streams"][0]["r_frame_rate"]
|
||||
|
||||
if framerate == -1:
|
||||
print(F"FRAME RATE DETECTED: {fr} (if this looks wrong, override the frame rate)")
|
||||
framerate = math.ceil(eval(fr))
|
||||
|
||||
print(F"FRAME RATE: {framerate} TOTAL FRAMES: {total_frames}")
|
||||
|
||||
p = multiprocessing.Process(target=capture_frames,
|
||||
args=(file_path, frames_dict, gpu_batchsize * prefetched_batches, total_frames))
|
||||
p.start()
|
||||
|
||||
# note I am deliberatley not using pool
|
||||
# we can't trust it to run all the threads concurrently (or at all)
|
||||
workers = [multiprocessing.Process(target=worker,
|
||||
args=(worker_nodes, wn, results_dict, model_name, gpu_batchsize, total_frames,
|
||||
frames_dict))
|
||||
for wn in range(worker_nodes)]
|
||||
for w in workers:
|
||||
w.start()
|
||||
|
||||
command = None
|
||||
proc = None
|
||||
frame_counter = 0
|
||||
for i in range(math.ceil(total_frames / worker_nodes)):
|
||||
for wx in range(worker_nodes):
|
||||
|
||||
hash_index = i * worker_nodes + 1 + wx
|
||||
|
||||
while hash_index not in results_dict:
|
||||
time.sleep(0.1)
|
||||
|
||||
frames = results_dict[hash_index]
|
||||
# dont block access to it anymore
|
||||
del results_dict[hash_index]
|
||||
|
||||
for frame in frames:
|
||||
if command is None:
|
||||
command = ['nice', '-10',
|
||||
'ffmpeg',
|
||||
'-y',
|
||||
'-f', 'rawvideo',
|
||||
'-vcodec', 'rawvideo',
|
||||
'-s', F"{frame.shape[1]}x320",
|
||||
'-pix_fmt', 'gray',
|
||||
'-r', F"{framerate}",
|
||||
'-i', '-',
|
||||
'-an',
|
||||
'-vcodec', 'mpeg4',
|
||||
'-b:v', '2000k',
|
||||
'%s' % output]
|
||||
|
||||
proc = sp.Popen(command, stdin=sp.PIPE)
|
||||
|
||||
proc.stdin.write(frame.tostring())
|
||||
frame_counter = frame_counter + 1
|
||||
|
||||
if frame_counter >= total_frames:
|
||||
p.join()
|
||||
for w in workers:
|
||||
w.join()
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
print(F"FINISHED ALL FRAMES ({total_frames})!")
|
||||
|
||||
return
|
||||
|
||||
p.join()
|
||||
for w in workers:
|
||||
w.join()
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
return
|
||||
|
||||
|
||||
def transparentgif(output, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
|
||||
matte_key(temp_file, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit,
|
||||
prefetched_batches,
|
||||
framerate)
|
||||
cmd = "nice -10 ffmpeg -y -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1,fps=10,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse' -shortest %s" % (
|
||||
file_path, temp_file, output)
|
||||
sp.run(shlex.split(cmd))
|
||||
print("Process finished")
|
||||
|
||||
return
|
||||
|
||||
|
||||
def transparentgifwithbackground(output, overlay, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
|
||||
matte_key(temp_file, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit,
|
||||
prefetched_batches,
|
||||
framerate)
|
||||
print("Starting alphamerge")
|
||||
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[fg];[2][fg]overlay=(main_w-overlay_w)/2:(main_h-overlay_h)/2:format=auto,fps=10,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse' -shortest %s" % (
|
||||
file_path, temp_file, overlay, output)
|
||||
sp.run(shlex.split(cmd))
|
||||
print("Process finished")
|
||||
|
||||
return
|
||||
|
||||
|
||||
def transparentvideo(output, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
|
||||
matte_key(temp_file, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit,
|
||||
prefetched_batches,
|
||||
framerate)
|
||||
print("Starting alphamerge")
|
||||
cmd = "nice -10 ffmpeg -y -nostats -loglevel 0 -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1' -c:v qtrle -shortest %s" % (
|
||||
file_path, temp_file, output)
|
||||
process = sp.Popen(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
|
||||
stdout, stderr = process.communicate()
|
||||
print('after call')
|
||||
|
||||
if stderr:
|
||||
return "ERROR: %s" % stderr.decode("utf-8")
|
||||
print("Process finished")
|
||||
|
||||
return
|
||||
|
||||
|
||||
def transparentvideoovervideo(output, overlay, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
|
||||
matte_key(temp_file, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit,
|
||||
prefetched_batches,
|
||||
framerate)
|
||||
print("Starting alphamerge")
|
||||
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[vid];[vid][2:v]scale2ref[fg][bg];[bg][fg]overlay=shortest=1[out]' -map [out] -shortest %s" % (
|
||||
file_path, temp_file, overlay, output)
|
||||
sp.run(shlex.split(cmd))
|
||||
print("Process finished")
|
||||
return
|
||||
|
||||
|
||||
def transparentvideooverimage(output, overlay, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit=-1,
|
||||
prefetched_batches=4,
|
||||
framerate=-1):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
temp_file = os.path.abspath("%s/matte.mp4" % tmpdirname)
|
||||
matte_key(temp_file, file_path,
|
||||
worker_nodes,
|
||||
gpu_batchsize,
|
||||
model_name,
|
||||
frame_limit,
|
||||
prefetched_batches,
|
||||
framerate)
|
||||
print("Scale image")
|
||||
temp_image = os.path.abspath("%s/new.jpg" % tmpdirname)
|
||||
cmd = "nice -10 ffmpeg -i %s -i %s -filter_complex 'scale2ref[img][vid];[img]setsar=1;[vid]nullsink' -q:v 2 %s" % (
|
||||
overlay, file_path, temp_image)
|
||||
sp.run(shlex.split(cmd))
|
||||
print("Starting alphamerge")
|
||||
cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[0][1]scale2ref[img][vid];[img]setsar=1[img];[vid]nullsink; [img][2]overlay=(W-w)/2:(H-h)/2' -shortest %s" % (
|
||||
#cmd = "nice -10 ffmpeg -y -i %s -i %s -i %s -filter_complex '[1][0]scale2ref[mask][main];[main][mask]alphamerge=shortest=1[vid];[2:v][vid]overlay[out]' -map [out] -shortest %s" % (
|
||||
temp_image, file_path, temp_file, output)
|
||||
sp.run(shlex.split(cmd))
|
||||
print("Process finished")
|
||||
return
|
Loading…
x
Reference in New Issue
Block a user