1
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+# except for the third-party components listed below.
3
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+# in the respective licenses of these third-party components.
5
+# Users must comply with all terms and conditions of original licenses of these third-party
6
+# components and must ensure that the usage of the third party components adheres to
7
+# all relevant laws and regulations.
8
+
9
+# For avoidance of doubts, Hunyuan 3D means the large language models and
10
+# their software and algorithms, including trained model weights, parameters (including
11
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+# fine-tuning enabling code and other elements of the foregoing made publicly available
13
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
1
15
"""
2
16
A model worker executes the model.
3
17
"""
22
36
from fastapi.responses import JSONResponse, FileResponse
23
37
24
38
from hy3dgen.rembg import BackgroundRemover
25
-from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer
39
+from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer, \
40
+ MeshSimplifier
26
41
from hy3dgen.texgen import Hunyuan3DPaintPipeline
27
42
from hy3dgen.text2image import HunyuanDiTPipeline
28
43
@@ -129,17 +144,31 @@ def load_image_from_base64(image):
129
144
130
145
131
146
class ModelWorker:
132
- def __init__(self, model_path='tencent/Hunyuan3D-2', device='cuda'):
147
+ def __init__(self,
148
+ model_path='tencent/Hunyuan3D-2mini',
149
+ tex_model_path='tencent/Hunyuan3D-2',
150
+ subfolder='hunyuan3d-dit-v2-mini-turbo',
151
+ device='cuda',
152
+ enable_tex=False):
133
153
self.model_path = model_path
134
154
self.worker_id = worker_id
135
155
self.device = device
136
156
logger.info(f"Loading the model {model_path} on worker {worker_id} ...")
137
157
138
158
self.rembg = BackgroundRemover()
139
- self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path, device=device)
140
- self.pipeline_t2i = HunyuanDiTPipeline('Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
141
- device=device)
142
- self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(model_path)
159
+ self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
160
+ model_path,
161
+ subfolder=subfolder,
162
+ use_safetensors=True,
163
+ device=device,
164
+ )
165
+ self.pipeline.enable_flashvdm()
166
+ # self.pipeline_t2i = HunyuanDiTPipeline(
167
+ # 'Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
168
+ # device=device
169
+ # )
170
+ if enable_tex:
171
+ self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(tex_model_path)
143
172
144
173
def get_queue_length(self):
145
174
if model_semaphore is None:
@@ -174,31 +203,42 @@ def generate(self, uid, params):
174
203
else:
175
204
seed = params.get("seed", 1234)
176
205
params['generator'] = torch.Generator(self.device).manual_seed(seed)
177
- params['octree_resolution'] = params.get("octree_resolution", 256)
178
- params['num_inference_steps'] = params.get("num_inference_steps", 30)
179
- params['guidance_scale'] = params.get('guidance_scale', 7.5)
180
- params['mc_algo'] = 'mc'
206
+ params['octree_resolution'] = params.get("octree_resolution", 128)
207
+ params['num_inference_steps'] = params.get("num_inference_steps", 5)
208
+ params['guidance_scale'] = params.get('guidance_scale', 5.0)
209
+ params['mc_algo'] = 'dmc'
210
+ import time
211
+ start_time = time.time()
181
212
mesh = self.pipeline(**params)[0]
213
+ logger.info("--- %s seconds ---" % (time.time() - start_time))
182
214
183
215
if params.get('texture', False):
184
216
mesh = FloaterRemover()(mesh)
185
217
mesh = DegenerateFaceRemover()(mesh)
186
218
mesh = FaceReducer()(mesh, max_facenum=params.get('face_count', 40000))
187
219
mesh = self.pipeline_tex(mesh, image)
188
220
189
- with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as temp_file:
221
+        file_type = params.get('type', 'glb')  # export format; renamed to avoid shadowing the builtin `type`
222
+        with tempfile.NamedTemporaryFile(suffix=f'.{file_type}', delete=True) as temp_file:  # NOTE(review): delete=True blocks reopen-by-name on Windows — confirm target OS
190
223
            mesh.export(temp_file.name)
191
224
            mesh = trimesh.load(temp_file.name)
192
-            temp_file.close()
193
-            os.unlink(temp_file.name)
194
-        save_path = os.path.join(SAVE_DIR, f'{str(uid)}.glb')
225
+        save_path = os.path.join(SAVE_DIR, f'{str(uid)}.{file_type}')
195
226
mesh.export(save_path)
196
227
197
228
torch.cuda.empty_cache()
198
229
return save_path, uid
199
230
200
231
201
232
app = FastAPI()
233
+from fastapi.middleware.cors import CORSMiddleware
234
+
235
+app.add_middleware(
236
+ CORSMiddleware,
237
+    allow_origins=["*"],  # TODO(review): restrict to known origins instead of "*"
238
+ allow_credentials=True,
239
+    allow_methods=["*"],  # allow all HTTP methods
240
+    allow_headers=["*"],  # allow all headers
241
+)
202
242
203
243
204
244
@app.post("/generate")
@@ -260,14 +300,17 @@ async def status(uid: str):
260
300
if __name__ == "__main__":
261
301
parser = argparse.ArgumentParser()
262
302
parser.add_argument("--host", type=str, default="0.0.0.0")
263
- parser.add_argument("--port", type=int, default=8081)
264
- parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2')
303
+    parser.add_argument("--port", type=int, default=8081)
304
+ parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
305
+ parser.add_argument("--tex_model_path", type=str, default='tencent/Hunyuan3D-2')
265
306
parser.add_argument("--device", type=str, default="cuda")
266
307
parser.add_argument("--limit-model-concurrency", type=int, default=5)
308
+ parser.add_argument('--enable_tex', action='store_true')
267
309
args = parser.parse_args()
268
310
logger.info(f"args: {args}")
269
311
270
312
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
271
313
272
- worker = ModelWorker(model_path=args.model_path, device=args.device)
314
+ worker = ModelWorker(model_path=args.model_path, device=args.device, enable_tex=args.enable_tex,
315
+ tex_model_path=args.tex_model_path)
273
316
uvicorn.run(app, host=args.host, port=args.port, log_level="info")