1
+
2
+ import json
3
+ import os
4
+ from server import PromptServer
5
+ from aiohttp import web ,ClientSession
6
+ import asyncio
7
+ import zipfile
8
+ import tempfile
9
+
10
+ CURRENT_DIR = os .path .dirname (os .path .abspath (__file__ ))
11
+
12
+ root_dir = os .path .join (CURRENT_DIR , '..' )
13
+
14
+ cache_dir = os .path .join (root_dir , '.cache' )
15
+ docs_dir = os .path .join (root_dir , 'docs' )
16
+ setting_path = os .path .join (cache_dir , '.setting.json' )
17
+ device_id_path = os .path .join (cache_dir , '.cache_d_id' )
18
+
19
+ # 根据设备名称和设备信息 ip地址,创建一个唯一的设备ID
20
+ def create_device_id ():
21
+ import hashlib
22
+ import socket
23
+ import uuid
24
+
25
+ # 获取设备名称
26
+ device_name = socket .gethostname ()
27
+ # 获取设备IP地址
28
+ device_ip = socket .gethostbyname (device_name )
29
+ # 获取设备唯一标识
30
+ device_uuid = str (uuid .getnode ())
31
+ # 创建设备ID
32
+ device_id = hashlib .md5 ((device_name + device_ip + device_uuid ).encode ('utf-8' )).hexdigest ()
33
+ return device_id
34
+
35
+ if not os .path .exists (cache_dir ):
36
+ os .makedirs (cache_dir , 755 )
37
+
38
+ # 获取节点文档内容
39
+ def get_node_doc_file_content (node_name ):
40
+ file_path = os .path .join (docs_dir , node_name + '.md' )
41
+ if os .path .exists (file_path ):
42
+ with open (file_path , 'r' , encoding = 'utf-8' ) as file :
43
+ return file .read ()
44
+ else :
45
+ return ""
46
+
47
+ # 获取节点缓存文档
48
+ def get_node_cache_file_content (node_name ):
49
+ file_path = os .path .join (cache_dir , node_name + '.md' )
50
+ if os .path .exists (file_path ):
51
+ with open (file_path , 'r' , encoding = 'utf-8' ) as file :
52
+ return file .read ()
53
+ else :
54
+ return ""
55
+
56
+ # 写入缓存文件
57
+ def write_cache_file (node_name , content ):
58
+ file_path = os .path .join (cache_dir , node_name + '.md' )
59
+ print (file_path )
60
+ with open (file_path , 'w' , encoding = 'utf-8' ) as file :
61
+ file .write (content )
62
+
63
+ def write_device_id ():
64
+ # 创建设备ID
65
+ device_id = create_device_id ()
66
+ # 写入缓存文件
67
+ with open (device_id_path , 'w' , encoding = 'utf-8' ) as file :
68
+ file .write (device_id )
69
+ return device_id
70
+
71
+ write_device_id ()
72
+
73
+ def get_device_id ():
74
+ if os .path .exists (device_id_path ):
75
+ with open (device_id_path , 'r' , encoding = 'utf-8' ) as file :
76
+ return file .read ()
77
+ else :
78
+ divce_id = write_device_id ()
79
+ return divce_id
80
+
81
+ # Add route to fetch node info
82
+ @PromptServer .instance .routes .get ("/customnode/getNodeInfo" )
83
+ async def fetch_customnode_node_info (request ):
84
+ try :
85
+ node_name = request .rel_url .query ["nodeName" ]
86
+
87
+ if not node_name :
88
+ return web .json_response ({"content" : "" })
89
+
90
+ cache_content = get_node_cache_file_content (node_name )
91
+
92
+ if len (cache_content ) > 0 :
93
+ return web .json_response ({"content" : cache_content })
94
+
95
+ content = get_node_doc_file_content (node_name )
96
+ return web .json_response ({"content" : content })
97
+ except Exception as e :
98
+ return web .json_response ({"content" : "" })
99
+
100
+ # Add route to cache node info
101
+ @PromptServer .instance .routes .get ("/customnode/cacheNodeInfo" )
102
+ async def cache_customnode_node_info (request ):
103
+ try :
104
+ node_name = request .rel_url .query ["nodeName" ]
105
+
106
+ if not node_name :
107
+ return web .json_response ({"success" : False , "content" : "" })
108
+
109
+ print ('cache start' )
110
+ if not os .path .exists (os .path .join (cache_dir , node_name + '.md' )):
111
+ print ('cache file not exists' )
112
+ content = get_node_doc_file_content (node_name )
113
+ write_cache_file (node_name , content )
114
+ print ('cache success' )
115
+ return web .json_response ({"success" : True , "content" : content })
116
+ return web .json_response ({"success" : True , "content" : '' })
117
+ except Exception as e :
118
+ print (e )
119
+ return web .json_response ({"success" : False , "content" : '' })
120
+
121
+ # Add route to update node info
122
+ @PromptServer .instance .routes .post ("/customnode/updateNodeInfo" )
123
+ async def update_customnode_node_info (request ):
124
+ try :
125
+ json_data = await request .json ()
126
+ node_name = json_data ["nodeName" ]
127
+ # node_name = request.rel_url.query["nodeName"]
128
+ content = json_data ["content" ]
129
+
130
+ if not node_name :
131
+ return web .json_response ({"success" : False })
132
+
133
+ write_cache_file (node_name , content )
134
+
135
+ contribute = get_setting_item ('contribute' )
136
+ if contribute == True :
137
+ print ('send doc to cloud' )
138
+ asyncio .create_task (send_doc_to_cloud (node_name , content ))
139
+
140
+ return web .json_response ({"success" : True })
141
+ except Exception as e :
142
+ return web .json_response ({"success" : False })
143
+
144
+
145
+ # ================================== 以下是导出节点文档的代码 ==================================
146
+ def get_all_files (directory ):
147
+ """ 获取目录下所有文件的路径 """
148
+ file_paths = []
149
+ for root , _ , files in os .walk (directory ):
150
+ for file in files :
151
+ file_paths .append (os .path .join (root , file ))
152
+ return file_paths
153
+
154
+ def collect_unique_files (dir1 , dir2 ):
155
+ """ 收集两个目录下的所有文件,按文件名去重 """
156
+ unique_files = {}
157
+ for file_path in get_all_files (dir1 ) + get_all_files (dir2 ):
158
+ file_name = os .path .basename (file_path )
159
+ if file_name not in unique_files :
160
+ unique_files [file_name ] = file_path
161
+ return unique_files .values ()
162
+
163
+ def zip_files (file_paths , output_zip ):
164
+ """ 将文件打包成ZIP,并平铺存储 """
165
+ with zipfile .ZipFile (output_zip , 'w' ) as zipf :
166
+ for file_path in file_paths :
167
+ arcname = os .path .basename (file_path ) # 只使用文件名,不保留路径
168
+ zipf .write (file_path , arcname = arcname )
169
+
170
+ # Add route to export node info
171
+ @PromptServer .instance .routes .get ("/customnode/exportNodeInfo" )
172
+ async def export_customnode_node_info (request ):
173
+ # 把所有的缓存文件和docs文件夹下的文件打包成zip文件
174
+ """ 生成ZIP文件并返回文件流响应 """
175
+
176
+ unique_files = collect_unique_files (cache_dir , docs_dir )
177
+
178
+ # 使用临时文件存储ZIP
179
+ temp_zip = tempfile .NamedTemporaryFile (delete = False )
180
+ try :
181
+ with zipfile .ZipFile (temp_zip , 'w' ) as zipf :
182
+ for file_path in unique_files :
183
+ arcname = os .path .basename (file_path )
184
+ zipf .write (file_path , arcname = arcname )
185
+
186
+ temp_zip .seek (0 )
187
+
188
+ # 生成StreamResponse响应
189
+ response = web .StreamResponse ()
190
+ response .headers ['Content-Type' ] = 'application/zip'
191
+ response .headers ['Content-Disposition' ] = 'attachment; filename="output.zip"'
192
+
193
+ await response .prepare (request )
194
+
195
+ with open (temp_zip .name , 'rb' ) as f :
196
+ while chunk := f .read (8192 ):
197
+ await response .write (chunk )
198
+
199
+ await response .write_eof ()
200
+ print (response )
201
+ return response
202
+ except Exception as e :
203
+ print (e )
204
+ finally :
205
+ os .remove (temp_zip .name ) # 删除临时文件
206
+
207
+
208
+ # Add route to import node info
209
+ @PromptServer .instance .routes .post ("/customnode/importNodeInfo" )
210
+ async def import_customnode_node_info (request ):
211
+ # 接收上传的ZIP文件并解压到指定目录
212
+ """ 接收上传的ZIP文件并解压到指定目录 """
213
+ try :
214
+ data = await request .post ()
215
+ zip_file = data .get ('file' )
216
+
217
+ # 保存上传的ZIP文件
218
+ zip_path = os .path .join (CURRENT_DIR , 'upload.zip' )
219
+ with open (zip_path , 'wb' ) as f :
220
+ f .write (zip_file .file .read ())
221
+
222
+ # 解压ZIP文件到.cache目录,并且覆盖原有文件
223
+ with zipfile .ZipFile (zip_path , 'r' ) as zipf :
224
+ zipf .extractall (cache_dir )
225
+
226
+ os .remove (zip_path ) # 删除上传的ZIP文件
227
+ return web .json_response ({"success" : True })
228
+ except Exception as e :
229
+ return web .json_response ({"success" : False })
230
+
231
+
232
+ # 获取设置
233
+ def get_setting ():
234
+ if os .path .exists (setting_path ):
235
+ with open (setting_path , 'r' , encoding = 'utf-8' ) as file :
236
+ return file .read ()
237
+ else :
238
+ return "{}"
239
+
240
+ # 获取设置的每一项
241
+ def get_setting_item (key ):
242
+ setting = get_setting ()
243
+ setting_json = json .loads (setting )
244
+ return setting_json [key ]
245
+
246
+ # 保存设置到文件
247
+ def save_setting (setting ):
248
+ with open (setting_path , 'w' , encoding = 'utf-8' ) as file :
249
+ file .write (setting )
250
+
251
+ if not os .path .exists (setting_path ):
252
+ save_setting ('{"contribute": true}' )
253
+
254
+ # 更新设置到本地.setting.json
255
+ @PromptServer .instance .routes .post ("/customnode/updateSetting" )
256
+ async def update_setting (request ):
257
+ try :
258
+ json_data = await request .json ()
259
+ setting = json_data
260
+ # 跟缓存文件的设置项每一项对比,如果有不同则更新
261
+ cache_setting = get_setting ()
262
+ cache_setting_json = json .loads (cache_setting )
263
+ for key in setting :
264
+ cache_setting_json [key ] = setting [key ]
265
+
266
+ # 保存设置到文件,缩近2空格
267
+ save_setting (json .dumps (cache_setting_json , indent = 2 ))
268
+
269
+ return web .json_response ({"success" : True })
270
+ except Exception as e :
271
+ return web .json_response ({"success" : False })
272
+
273
+ # 发送当前文档到云
274
+ async def send_doc_to_cloud (node_type , content ):
275
+ url = 'http://comfy.zukmb.cn/api/saveNodesDocs'
276
+ async with ClientSession () as session :
277
+ device_id = get_device_id ()
278
+ data = {
279
+ 'device_id' : device_id ,
280
+ 'node_type' : node_type ,
281
+ 'content' : content
282
+ }
283
+ print ('send doc to cloud' )
284
+ async with session .post (url , data = data ) as response :
285
+ print (await response .text ())
286
+ # try:
287
+
288
+ # # url = 'http://localhost:8080/api/saveNodesDocs'
289
+ # # data = {
290
+ # # 'device_id': device_id,
291
+ # # 'node_type': node_type,
292
+ # # 'content': content
293
+ # # }
294
+ # # requests.post(url, data=data)
295
+ # except Exception as e:
296
+ # print(e)
0 commit comments