Stable Diffusion自动化生产落地指南
一、总体架构与关键组件
二、落地步骤与最小可用示例
{
"tasks": [
{
"task_id": "cat_01",
"prompt": "a cyberpunk cat, neon lights, 4k",
"negative_prompt": "blurry, low quality",
"seed": 42,
"steps": 30,
"cfg_scale": 7.5,
"model": "runwayml/stable-diffusion-v1-5",
"output_dir": "output/"
}
],
"global_settings": {
"enable_logging": true,
"max_retries": 3,
"parallel_workers": 2
}
}
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from PIL import Image, PngImagePlugin
from diffusers import StableDiffusionPipeline
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.FileHandler("pipeline.log"), logging.StreamHandler()]
)
class StableDiffusionAutomator:
    """Batch image-generation driver for Stable Diffusion.

    Reads a JSON config describing per-image tasks plus global settings,
    renders each task to ``<output_dir>/<task_id>.png`` with the generation
    parameters embedded as PNG metadata, retries failures, and fans tasks
    out across a bounded thread pool.

    Fixes vs. the original:
      * ``_init_pipeline`` used ``torch.float16`` but ``torch`` was only
        imported locally inside ``_generate_one`` -> NameError at runtime.
      * ``output_dir`` was never created, so ``image.save`` failed on a
        fresh checkout.
      * Lazy pipeline init was unsynchronized under the thread pool (two
        workers could both build a pipeline) and silently reused the first
        task's model for every later task; init is now lock-guarded and a
        model mismatch is logged.
      * An empty task list produced ``ThreadPoolExecutor(max_workers=0)``
        which raises ValueError.
    """

    def __init__(self, config_path):
        """Load and validate the JSON config at *config_path*.

        Raises:
            ValueError: if any task is missing a required field.
        """
        self.config = self._load_config(config_path)
        self._validate_config()
        self.pipe = None            # lazily built diffusers pipeline
        self._loaded_model = None   # model id the current pipeline was built from
        self._pipe_lock = threading.Lock()  # guards lazy init across workers

    def _load_config(self, path):
        """Return the parsed JSON config dict."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _validate_config(self):
        """Raise ValueError if any task lacks a required field."""
        required = ['task_id', 'prompt', 'model', 'output_dir']
        for t in self.config['tasks']:
            for field in required:
                if field not in t:
                    raise ValueError(f"任务 {t.get('task_id','未知')} 缺少字段: {field}")

    def _init_pipeline(self, model_id):
        """Build the diffusers pipeline for *model_id* on CUDA and cache it.

        Memory savers (attention/VAE slicing, xformers) are enabled on a
        best-effort basis; xformers absence is downgraded to a warning.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        ).to("cuda")
        pipe.enable_attention_slicing()
        if hasattr(pipe, "enable_vae_slicing"):
            pipe.enable_vae_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            logging.warning("xformers不可用,使用标准注意力")
        self.pipe = pipe
        self._loaded_model = model_id

    def _generate_one(self, task):
        """Render one task; returns ``(success: bool, task_id: str)``.

        Missing optional fields get defaults (seed=-1 meaning "random",
        steps=20, cfg_scale=7.0, empty negative prompt). Each attempt is
        retried up to ``global_settings.max_retries`` times.
        """
        task.setdefault('seed', -1)
        task.setdefault('steps', 20)
        task.setdefault('cfg_scale', 7.0)
        task.setdefault('negative_prompt', '')
        # Lock-guarded lazy init: prevents two workers building a pipeline
        # at once. Only one model is kept resident; a differing request is
        # logged rather than triggering a concurrent reload.
        with self._pipe_lock:
            if self.pipe is None:
                self._init_pipeline(task['model'])
            elif task['model'] != self._loaded_model:
                logging.warning(
                    f"任务 {task['task_id']} 请求模型 {task['model']},"
                    f"已加载 {self._loaded_model},将复用已加载模型"
                )
        seed = task['seed'] if task['seed'] != -1 else int(time.time())
        g = torch.Generator("cuda").manual_seed(seed)
        # Original crashed with FileNotFoundError when output_dir was absent.
        os.makedirs(task['output_dir'], exist_ok=True)
        out_path = os.path.join(task['output_dir'], f"{task['task_id']}.png")
        retries = self.config['global_settings']['max_retries']
        for attempt in range(retries):
            try:
                image = self.pipe(
                    prompt=task['prompt'],
                    negative_prompt=task['negative_prompt'],
                    num_inference_steps=task['steps'],
                    guidance_scale=task['cfg_scale'],
                    generator=g,
                ).images[0]
                # Embed the full generation recipe so any output PNG is
                # reproducible from its own metadata.
                meta = PngImagePlugin.PngInfo()
                meta.add_text("task_id", task['task_id'])
                meta.add_text("prompt", task['prompt'])
                meta.add_text("negative_prompt", task['negative_prompt'])
                meta.add_text("seed", str(g.initial_seed()))
                meta.add_text("steps", str(task['steps']))
                meta.add_text("cfg_scale", str(task['cfg_scale']))
                meta.add_text("model", task['model'])
                image.save(out_path, pnginfo=meta)
                logging.info(f"完成: {task['task_id']} -> {out_path}")
                return True, task['task_id']
            except Exception as e:
                logging.warning(f"重试 {attempt+1}/{retries} 失败: {e}")
                time.sleep(1)
        logging.error(f"任务失败: {task['task_id']}")
        return False, task['task_id']

    def run(self):
        """Execute all configured tasks concurrently and log a summary."""
        os.makedirs(self.config['global_settings'].get('log_dir', 'logs'), exist_ok=True)
        tasks = self.config['tasks']
        # max(1, ...): ThreadPoolExecutor rejects max_workers=0 (empty task list).
        workers = max(1, min(self.config['global_settings']['parallel_workers'], len(tasks)))
        ok = fail = 0
        with ThreadPoolExecutor(max_workers=workers) as ex:
            futures = {ex.submit(self._generate_one, t): t for t in tasks}
            for fut in as_completed(futures):
                success, _ = fut.result()
                ok += int(success)
                fail += int(not success)
        logging.info(f"全部完成: 成功 {ok}, 失败 {fail}")
if __name__ == "__main__":
StableDiffusionAutomator("config.json").run()三、性能优化与规模化实践
四、企业级部署与运维要点
五、常见故障排查清单