Skip to content

Commit

Permalink
fix(app): 更新权重文件路径并调整 GPU 设备
Browse files Browse the repository at this point in the history
- 将所有权重文件的扩展名从 .bin 改为 .pth
- 将 GPU 设备从 "cuda:1" 修改为 "cuda:0"
- 更新了多个卫星数据处理模块的权重文件路径
  • Loading branch information
caixiaoshun committed Nov 22, 2024
1 parent edea9ff commit e312248
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions hugging_face/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,14 @@ def get_palette(dataset_name: str) -> List[int]:
l2a_examples = glob("example_inputs/l2a/*")
l8_examples = glob("example_inputs/l8/*")

device = "cuda:1" if torch.cuda.is_available() else "cpu"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
with gr.Blocks(analytics_enabled=False, title=title,css=custom_css) as demo:
gr.Markdown(f'# {title}')
with gr.Tabs():
with gr.TabItem('Google Earth'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/binary_classes_256x256.py",
checkpoint_path="checkpoints/cloud-adapter/hrc_whu_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/hrc_whu_full_weight.pth",
device=device,
example_inputs=hrc_whu_examples,
num_classes=2,
Expand All @@ -180,7 +180,7 @@ def get_palette(dataset_name: str) -> List[int]:
with gr.TabItem('Gaofen-1'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/binary_classes_256x256.py",
checkpoint_path="checkpoints/cloud-adapter/gf1_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/gf1_full_weight.pth",
device=device,
example_inputs=gf1_examples,
num_classes=2,
Expand All @@ -189,7 +189,7 @@ def get_palette(dataset_name: str) -> List[int]:
with gr.TabItem('Gaofen-2'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/binary_classes_256x256.py",
checkpoint_path="checkpoints/cloud-adapter/gf2_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/gf2_full_weight.pth",
device=device,
example_inputs=gf2_examples,
num_classes=2,
Expand All @@ -199,7 +199,7 @@ def get_palette(dataset_name: str) -> List[int]:
with gr.TabItem('Sentinel-2 (L1C)'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/multi_classes_512x512.py",
checkpoint_path="checkpoints/cloud-adapter/l1c_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/l1c_full_weight.pth",
device=device,
example_inputs=l1c_examples,
num_classes=4,
Expand All @@ -208,7 +208,7 @@ def get_palette(dataset_name: str) -> List[int]:
with gr.TabItem('Sentinel-2 (L2A)'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/multi_classes_512x512.py",
checkpoint_path="checkpoints/cloud-adapter/l2a_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/l2a_full_weight.pth",
device=device,
example_inputs=l2a_examples,
num_classes=4,
Expand All @@ -217,7 +217,7 @@ def get_palette(dataset_name: str) -> List[int]:
with gr.TabItem('Landsat-8'):
CloudAdapterGradio(
config_path="cloud-adapter-configs/multi_classes_512x512.py",
checkpoint_path="checkpoints/cloud-adapter/l8_full_weight.bin",
checkpoint_path="checkpoints/cloud-adapter/l8_full_weight.pth",
device=device,
example_inputs=l8_examples,
num_classes=4,
Expand Down

0 comments on commit e312248

Please sign in to comment.