background remover as a library

This commit is contained in:
Ahmad Alobaid 2024-04-07 16:39:13 +03:00
parent ecd561b61f
commit 3e9804b8ed
6 changed files with 72 additions and 43 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea/
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python

View File

@ -137,6 +137,27 @@ change the model for different background removal methods between `u2netp`, `u2n
backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg" -fl 150 -tv -o "output.mov"
```
## As a library
### Remove background image
```
from backgroundremover.bg import remove
def remove_bg(src_img_path, out_img_path):
model_choices = ["u2net", "u2net_human_seg", "u2netp"]
f = open(src_img_path, "rb")
data = f.read()
img = remove(data, model_name=model_choices[0],
alpha_matting=True,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_structure_size=10,
alpha_matting_base_size=1000)
f.close()
f = open(out_img_path, "wb")
f.write(img)
f.close()
```
## Todo
- convert logic from video to image to utilize more GPU on image removal

View File

@ -13,7 +13,7 @@ import torch.nn.functional
import torch.nn.functional
from hsh.library.hash import Hasher
from .u2net import detect, u2net
from . import utilities
from . import github
# closes https://github.com/nadermx/backgroundremover/issues/18
# closes https://github.com/nadermx/backgroundremover/issues/112
@ -56,7 +56,7 @@ class Net(torch.nn.Module):
if (
not os.path.exists(path)
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
@ -70,7 +70,7 @@ class Net(torch.nn.Module):
not os.path.exists(path)
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
@ -84,7 +84,7 @@ class Net(torch.nn.Module):
not os.path.exists(path)
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
else:

View File

@ -0,0 +1,38 @@
import os
import requests
def download_files_from_github(path, model_name):
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
return
print(f"downloading model [{model_name}] to {path} ...")
urls = []
if model_name == "u2net":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
elif model_name == "u2net_human_seg":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
elif model_name == 'u2netp':
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
try:
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
except Exception as e:
print(f"Error creating directory: {e}")
return
try:
with open(path, 'wb') as out_file:
for i, url in enumerate(urls):
print(f'downloading part {i+1} of {model_name}')
part_content = requests.get(url)
out_file.write(part_content.content)
print(f'finished downloading part {i+1} of {model_name}')
except Exception as e:
print(e)

View File

@ -8,7 +8,8 @@ from PIL import Image
from torchvision import transforms
from . import data_loader, u2net
from .. import utilities
from .. import github
def load_model(model_name: str = "u2net"):
hasher = Hasher()
@ -38,7 +39,7 @@ def load_model(model_name: str = "u2net"):
not os.path.exists(path)
#or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
@ -48,11 +49,14 @@ def load_model(model_name: str = "u2net"):
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
print(f"DEBUG: path to be checked: {path}")
if (
not os.path.exists(path)
#or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)
@ -66,7 +70,7 @@ def load_model(model_name: str = "u2net"):
not os.path.exists(path)
#or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
utilities.download_files_from_github(
github.download_files_from_github(
path, model_name
)

View File

@ -328,38 +328,3 @@ def transparentvideooverimage(output, overlay, file_path,
except PermissionError:
pass
return
def download_files_from_github(path, model_name):
if model_name not in ["u2net", "u2net_human_seg", "u2netp"]:
print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'")
return
print(f"downloading model [{model_name}] to {path} ...")
urls = []
if model_name == "u2net":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad']
elif model_name == "u2net_human_seg":
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac',
'https://github.com/nadermx/backgroundremover/raw/main/models/u2had']
elif model_name == 'u2netp':
urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth']
try:
os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
except Exception as e:
print(f"Error creating directory: {e}")
return
try:
with open(path, 'wb') as out_file:
for i, url in enumerate(urls):
print(f'downloading part {i+1} of {model_name}')
part_content = requests.get(url)
out_file.write(part_content.content)
print(f'finished downloading part {i+1} of {model_name}')
except Exception as e:
print(e)