Skip to content

Commit

Permalink
refactor(hugging_face): 重构代码并适配 Hugging Face 格式
Browse files Browse the repository at this point in the history
- 重命名模型和适配器类,统一命名风格
- 移除不必要的模型导入和权重加载逻辑
- 简化推理函数,专注于 Cloud-Adapter 模型
- 更新配置文件和模型定义,使用新的类名
- 添加 requirements.txt 文件,列出项目依赖
  • Loading branch information
caixiaoshun committed Nov 22, 2024
1 parent 67e5d93 commit be8ce54
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 109 deletions.
13 changes: 13 additions & 0 deletions hugging_face/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
title: Cloud Adapter
emoji: 🏢
colorFrom: red
colorTo: purple
sdk: gradio
sdk_version: 5.6.0
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
97 changes: 6 additions & 91 deletions hugging_face/app.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
from mmseg.apis import init_model
from typing import List
from glob import glob
from cloud_adapter.cloud_adapter_dinov2 import OursAdapterDinoVisionTransformer
from cloud_adapter.cloud_adapter_dinov2 import CloudAdapterDinoVisionTransformer
import numpy as np
from PIL import Image
from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
import gradio as gr
import torch
import os
from cloud_adapter.cdnetv1 import CDnetV1
from cloud_adapter.cdnetv2 import CDnetV2
from cloud_adapter.dbnet import DBNet
from cloud_adapter.hrcloudnet import HRCloudNet
from cloud_adapter.kappamask import KappaMask
from cloud_adapter.mcdnet import MCDNet
from cloud_adapter.scnn import SCNN
from cloud_adapter.unetmobv2 import UNetMobV2


class CloudAdapterGradio:
def __init__(self, config_path=None, checkpoint_path=None, device="cpu", example_inputs=None, num_classes=2, palette=None, other_model_weight_path=None):
def __init__(self, config_path=None, checkpoint_path=None, device="cpu", example_inputs=None, num_classes=2, palette=None):
self.config_path = config_path
self.checkpoint_path = checkpoint_path
self.device = device
Expand All @@ -30,19 +22,6 @@ def __init__(self, config_path=None, checkpoint_path=None, device="cpu", example
self.img_size = 256 if num_classes == 2 else 512
self.palette = palette
self.legend = self.html_legend(num_classes=num_classes)

self.other_models = {
"cdnetv1": CDnetV1(num_classes=num_classes).to(self.device),
"cdnetv2": CDnetV2(num_classes=num_classes).to(self.device),
"hrcloudnet": HRCloudNet(num_classes=num_classes).to(self.device),
"mcdnet": MCDNet(in_channels=3, num_classes=num_classes).to(self.device),
"scnn": SCNN(num_classes=num_classes).to(self.device),
"dbnet": DBNet(img_size=self.img_size, in_channels=3, num_classes=num_classes).to(
self.device
),
"unetmobv2": UNetMobV2(num_classes=num_classes).to(self.device),
"kappamask": KappaMask(num_classes=num_classes, in_channels=3).to(self.device)
}
self.name_mapping = {
"KappaMask": "kappamask",
"CDNetv1": "cdnetv1",
Expand All @@ -55,19 +34,8 @@ def __init__(self, config_path=None, checkpoint_path=None, device="cpu", example
"Cloud-Adapter": "cloud-adapter",
}

self.load_weights(other_model_weight_path)

self.create_ui()

def load_weights(self, checkpoint_path: str):
for model_name, model in self.other_models.items():
weight_path = os.path.join(checkpoint_path, model_name+".bin")
weight_path = glob(weight_path)[0]
weight = torch.load(weight_path, map_location=self.device)
model.load_state_dict(weight)
model.eval()
print(f"Loaded {model_name} weights from {weight_path}")

def html_legend(self, num_classes=2):
if num_classes == 2:
return """
Expand Down Expand Up @@ -115,23 +83,6 @@ def create_ui(self):
type="pil",
)
with gr.Row():
# 增加一个下拉菜单
model_choice = gr.Dropdown(
choices=[
"Cloud-Adapter",
"DBNet",
"HRCloudNet",
"CDNetv2",
"UNetMobv2",
"CDNetv1",
"MCDNet",
"KappaMask",
"SCNN",
],
value="Cloud-Adapter",
label="Model",
elem_classes='model_type',
)
run_button = gr.Button(
'Run',
variant="primary",
Expand Down Expand Up @@ -161,43 +112,13 @@ def create_ui(self):
# 按钮点击逻辑:触发图像转换
run_button.click(
self.inference,
inputs=[in_image, model_choice],
inputs=in_image,
outputs=out_image,
)

@torch.no_grad()
def inference(self, image: Image.Image, model_choice: str) -> Image.Image:

if model_choice == "Cloud-Adapter":
return self.cloud_adapter_forward(image)
return self.simple_model_forward(image, self.name_mapping[model_choice])

@torch.no_grad()
def simple_model_forward(self, image: Image.Image, model_choice: str) -> Image.Image:
"""
Simple Model Inference
"""
ori_size = image.size
image = image.resize((self.img_size, self.img_size),
resample=Image.Resampling.BILINEAR)
image = np.array(image)
image = (image - np.min(image)) / (np.max(image)-np.min(image))

image = torch.from_numpy(image).unsqueeze(0).to(self.device)
image = image.permute(0, 3, 1, 2).float()

logits: torch.Tensor = self.other_models[model_choice].forward(image)
pred_mask = torch.argmax(logits, dim=1).squeeze(
0).cpu().numpy().astype(np.uint8)

del image
del logits
if torch.cuda.is_available():
torch.cuda.empty_cache()

im = Image.fromarray(pred_mask).convert("P")
im.putpalette(self.palette)
return im.resize(ori_size, resample=Image.Resampling.BILINEAR)
def inference(self, image: Image.Image) -> Image.Image:
return self.cloud_adapter_forward(image)

@torch.no_grad()
def cloud_adapter_forward(self, image: Image.Image) -> Image.Image:
Expand Down Expand Up @@ -254,7 +175,7 @@ def get_palette(dataset_name: str) -> List[int]:
l2a_examples = glob("example_inputs/l2a/*")
l8_examples = glob("example_inputs/l8/*")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cuda:1" if torch.cuda.is_available() else "cpu"
with gr.Blocks(analytics_enabled=False, title=title,css=custom_css) as demo:
gr.Markdown(f'# {title}')
with gr.Tabs():
Expand All @@ -266,7 +187,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=hrc_whu_examples,
num_classes=2,
palette=get_palette("hrc_whu"),
other_model_weight_path="checkpoints/hrc_whu"
)
with gr.TabItem('Gaofen-1'):
CloudAdapterGradio(
Expand All @@ -276,7 +196,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=gf1_examples,
num_classes=2,
palette=get_palette("gf12ms_whu_gf1"),
other_model_weight_path="checkpoints/gf12ms_whu_gf1"
)
with gr.TabItem('Gaofen-2'):
CloudAdapterGradio(
Expand All @@ -286,7 +205,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=gf2_examples,
num_classes=2,
palette=get_palette("gf12ms_whu_gf2"),
other_model_weight_path="checkpoints/gf12ms_whu_gf2"
)

with gr.TabItem('Sentinel-2 (L1C)'):
Expand All @@ -297,7 +215,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=l1c_examples,
num_classes=4,
palette=get_palette("cloudsen12_high_l1c"),
other_model_weight_path="checkpoints/cloudsen12_high_l1c"
)
with gr.TabItem('Sentinel-2 (L2A)'):
CloudAdapterGradio(
Expand All @@ -307,7 +224,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=l2a_examples,
num_classes=4,
palette=get_palette("cloudsen12_high_l2a"),
other_model_weight_path="checkpoints/cloudsen12_high_l2a"
)
with gr.TabItem('Landsat-8'):
CloudAdapterGradio(
Expand All @@ -317,7 +233,6 @@ def get_palette(dataset_name: str) -> List[int]:
example_inputs=l8_examples,
num_classes=4,
palette=get_palette("l8_biome"),
other_model_weight_path="checkpoints/l8_biome"
)

demo.launch(share=True, debug=True)
1 change: 0 additions & 1 deletion hugging_face/checkpoints/.gitkeep

This file was deleted.

6 changes: 3 additions & 3 deletions hugging_face/cloud-adapter-configs/binary_classes_256x256.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
init_values=1e-05,
mlp_ratio=4,
num_heads=16,
ours_adapter_config=dict(
cloud_adapter_config=dict(
cnn_type='pmaa',
context_dim=64,
depth=4,
Expand All @@ -53,11 +53,11 @@
rank_dim=16,
return_last_feature=False,
return_multi_feats=False,
type='OursAdapter'),
type='CloudAdapter'),
patch_size=16,
proj_bias=True,
qkv_bias=True,
type='OursAdapterDinoVisionTransformer'),
type='CloudAdapterDinoVisionTransformer'),
data_preprocessor=dict(
bgr_to_rgb=True,
mean=[
Expand Down
6 changes: 3 additions & 3 deletions hugging_face/cloud-adapter-configs/multi_classes_512x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
init_values=1e-05,
mlp_ratio=4,
num_heads=16,
ours_adapter_config=dict(
cloud_adapter_config=dict(
cnn_type='pmaa',
context_dim=64,
depth=4,
Expand All @@ -53,11 +53,11 @@
rank_dim=16,
return_last_feature=False,
return_multi_feats=False,
type='OursAdapter'),
type='CloudAdapter'),
patch_size=16,
proj_bias=True,
qkv_bias=True,
type='OursAdapterDinoVisionTransformer'),
type='CloudAdapterDinoVisionTransformer'),
data_preprocessor=dict(
bgr_to_rgb=True,
mean=[
Expand Down
1 change: 1 addition & 0 deletions hugging_face/cloud_adapter/cdnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def forward(self, x):
pred, pred_aux = output, auxout
pred = F.interpolate(pred, size, mode='bilinear', align_corners=True)
pred_aux = F.interpolate(pred_aux, size, mode='bilinear', align_corners=True)
return pred
return pred, pred_aux


Expand Down
2 changes: 1 addition & 1 deletion hugging_face/cloud_adapter/cloud_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def forward(self, x, cache, layer):


@MODELS.register_module()
class OursAdapter(nn.Module):
class CloudAdapter(nn.Module):
def __init__(self,
cnn_type="convnext", # convnext or mobilenet
int_type="convnext", # cross_attention or
Expand Down
18 changes: 9 additions & 9 deletions hugging_face/cloud_adapter/cloud_adapter_dinov2.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from mmseg.models.builder import BACKBONES, MODELS
from torch import nn as nn
from .cloud_adapter import OursAdapter
from .cloud_adapter import CloudAdapter
from .dino_v2 import DinoVisionTransformer
from .utils import set_requires_grad, set_train
import torch
import torch.nn.functional as F


@BACKBONES.register_module()
class OursAdapterDinoVisionTransformer(DinoVisionTransformer):
class CloudAdapterDinoVisionTransformer(DinoVisionTransformer):
def __init__(
self,
ours_adapter_config=None,
cloud_adapter_config=None,
has_cat=False,
# [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, ],
adapter_index=[0, 6, 12, 18], # Transformer Block 的索引
**kwargs,
):
super().__init__(**kwargs)
self.ours_adapter: OursAdapter = MODELS.build(ours_adapter_config)
self.cloud_adapter: CloudAdapter = MODELS.build(cloud_adapter_config)
self.has_cat = has_cat
self.adapter_index = adapter_index

def forward_features(self, x, masks=None):
B, _, h, w = x.shape
cache = self.ours_adapter.cnn(x) # 得到多尺度特征或者单个特征
cache = self.cloud_adapter.cnn(x) # 得到多尺度特征或者单个特征
H, W = h // self.patch_size, w // self.patch_size
x = self.prepare_tokens_with_masks(x, masks)
outs = []
cur_idx = 0 # 交互模块的索引
for idx, blk in enumerate(self.blocks):
x = blk(x)
if idx in self.adapter_index:
x = self.ours_adapter.forward(
x = self.cloud_adapter.forward(
x,
cur_idx,
batch_first=True,
Expand Down Expand Up @@ -102,12 +102,12 @@ def forward(self, *args, **kwargs):
def train(self, mode: bool = True):
if not mode:
return super().train(mode)
set_requires_grad(self, ["ours_adapter"])
set_train(self, ["ours_adapter"])
set_requires_grad(self, ["cloud_adapter"])
set_train(self, ["cloud_adapter"])

def state_dict(self, destination, prefix, keep_vars):
state = super().state_dict(destination, prefix, keep_vars)
keys = [k for k in state.keys() if "ours_adapter" not in k]
keys = [k for k in state.keys() if "cloud_adapter" not in k]
for key in keys:
state.pop(key)
if key in destination:
Expand Down
1 change: 0 additions & 1 deletion hugging_face/cloud_adapter/mcdnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# @Email : 3038523973@qq.com
# @File : mcdnet.py
# @Software: PyCharm
import cv2
import image_dehazer
import numpy as np
# 论文地址:https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub
Expand Down
24 changes: 24 additions & 0 deletions hugging_face/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
einops==0.8.0
segmentation_models_pytorch
gradio==4.44.1
huggingface-hub==0.26.0
image_dehazer==0.0.9

mmcv==2.1.0
-f https://download.openmmlab.com/mmcv/dist/cpu/torch2.0/index.html

mmdet==3.3.0
mmengine==0.10.5
mmpretrain==1.2.0
mmsegmentation==1.2.2
timm==0.9.2


torch==2.0.1
-f https://download.pytorch.org/whl/cpu
torchvision==0.15.2
torchaudio==2.0.2
ftfy
regex
yacs
numpy==1.24.2

0 comments on commit be8ce54

Please sign in to comment.