background remover as a library
This commit is contained in:
parent
ecd561b61f
commit
3e9804b8ed
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
.idea/
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
||||
|
||||
|
21
README.md
21
README.md
@ -137,6 +137,27 @@ change the model for different background removal methods between `u2netp`, `u2n
|
||||
backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg" -fl 150 -tv -o "output.mov"
|
||||
```
|
||||
|
||||
## As a library
|
||||
### Remove background image
|
||||
|
||||
```
|
||||
from backgroundremover.bg import remove
|
||||
def remove_bg(src_img_path, out_img_path):
|
||||
model_choices = ["u2net", "u2net_human_seg", "u2netp"]
|
||||
f = open(src_img_path, "rb")
|
||||
data = f.read()
|
||||
img = remove(data, model_name=model_choices[0],
|
||||
alpha_matting=True,
|
||||
alpha_matting_foreground_threshold=240,
|
||||
alpha_matting_background_threshold=10,
|
||||
alpha_matting_erode_structure_size=10,
|
||||
alpha_matting_base_size=1000)
|
||||
f.close()
|
||||
f = open(out_img_path, "wb")
|
||||
f.write(img)
|
||||
f.close()
|
||||
```
|
||||
|
||||
## Todo
|
||||
|
||||
- convert logic from video to image to utilize more GPU on image removal
|
||||
|
@ -13,7 +13,7 @@ import torch.nn.functional
|
||||
import torch.nn.functional
|
||||
from hsh.library.hash import Hasher
|
||||
from .u2net import detect, u2net
|
||||
from . import utilities
|
||||
from . import github
|
||||
|
||||
# closes https://github.com/nadermx/backgroundremover/issues/18
|
||||
# closes https://github.com/nadermx/backgroundremover/issues/112
|
||||
@ -56,7 +56,7 @@ class Net(torch.nn.Module):
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
|
||||
@ -70,7 +70,7 @@ class Net(torch.nn.Module):
|
||||
not os.path.exists(path)
|
||||
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
|
||||
@ -84,7 +84,7 @@ class Net(torch.nn.Module):
|
||||
not os.path.exists(path)
|
||||
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
else:
|
||||
|
38
backgroundremover/github.py
Normal file
38
backgroundremover/github.py
Normal file
@ -0,0 +1,38 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
|
||||
def download_files_from_github(path, model_name):
|
||||
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
|
||||
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
|
||||
return
|
||||
print(f"downloading model [{model_name}] to {path} ...")
|
||||
urls = []
|
||||
if model_name == "u2net":
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
|
||||
elif model_name == "u2net_human_seg":
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
|
||||
elif model_name == 'u2netp':
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
|
||||
try:
|
||||
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
|
||||
except Exception as e:
|
||||
print(f"Error creating directory: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
with open(path, 'wb') as out_file:
|
||||
for i, url in enumerate(urls):
|
||||
print(f'downloading part {i+1} of {model_name}')
|
||||
part_content = requests.get(url)
|
||||
out_file.write(part_content.content)
|
||||
print(f'finished downloading part {i+1} of {model_name}')
|
||||
except Exception as e:
|
||||
print(e)
|
@ -8,7 +8,8 @@ from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from . import data_loader, u2net
|
||||
from .. import utilities
|
||||
from .. import github
|
||||
|
||||
|
||||
def load_model(model_name: str = "u2net"):
|
||||
hasher = Hasher()
|
||||
@ -38,7 +39,7 @@ def load_model(model_name: str = "u2net"):
|
||||
not os.path.exists(path)
|
||||
#or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
|
||||
@ -48,11 +49,14 @@ def load_model(model_name: str = "u2net"):
|
||||
"U2NET_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
|
||||
print(f"DEBUG: path to be checked: {path}")
|
||||
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
|
||||
@ -66,7 +70,7 @@ def load_model(model_name: str = "u2net"):
|
||||
not os.path.exists(path)
|
||||
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
||||
):
|
||||
utilities.download_files_from_github(
|
||||
github.download_files_from_github(
|
||||
path, model_name
|
||||
)
|
||||
|
||||
|
@ -328,38 +328,3 @@ def transparentvideooverimage(output, overlay, file_path,
|
||||
except PermissionError:
|
||||
pass
|
||||
return
|
||||
|
||||
def download_files_from_github(path, model_name):
|
||||
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
|
||||
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
|
||||
return
|
||||
print(f"downloading model [{model_name}] to {path} ...")
|
||||
urls = []
|
||||
if model_name == "u2net":
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
|
||||
elif model_name == "u2net_human_seg":
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
|
||||
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
|
||||
elif model_name == 'u2netp':
|
||||
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
|
||||
try:
|
||||
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
|
||||
except Exception as e:
|
||||
print(f"Error creating directory: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
with open(path, 'wb') as out_file:
|
||||
for i, url in enumerate(urls):
|
||||
print(f'downloading part {i+1} of {model_name}')
|
||||
part_content = requests.get(url)
|
||||
out_file.write(part_content.content)
|
||||
print(f'finished downloading part {i+1} of {model_name}')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
Loading…
x
Reference in New Issue
Block a user