使用 FastAPI、Groq 和 Replicate 的 AI 图像生成器和故事生成应用程序
- Rifx.Online
- Programming , Technology/Web , Generative AI
- 08 Nov, 2024
项目介绍:AI 图像生成器和故事创作工具
AI 图像生成器和故事创作工具是一个网络应用程序,利用先进的 AI 技术为用户提供一个基于音频提示生成图像和故事的互动平台。该应用程序使用 FastAPI 作为后端,能够高效处理请求和响应,而前端则采用 HTML、CSS(DaisyUI 和 Tailwind CSS)和 JavaScript 构建,以提供响应式用户体验。该应用程序利用 llama-3.1–70b 进行提示生成,black-forest-labs/flux-1.1-pro 进行图像生成,以及 llava-v1.5–7b vbision 模型通过 Groq 和 Replicat.AI 分别进行故事创作。
主要特点:
- 音频录制和转录:用户可以录制他们的语音提示,然后使用语音识别技术将其转录为文本。
2. 图像生成:基于转录的文本,应用程序生成详细的图像提示,并使用 Replicate API 创建相应的图像。
3. 图像下载:用户可以将生成的图像下载到本地设备。
4. 故事生成:应用程序可以基于创建的图像生成引人入胜的故事,为视觉内容提供叙事背景。
5. 用户友好的界面:应用程序具有简洁直观的界面,使用户能够轻松与各种功能进行交互。
使用的技术:
- 后端:FastAPI, Groq, Replicate.ai, SpeechRecognition
- 前端:HTML, CSS (DaisyUI, Tailwind CSS), JavaScript
- 图像处理:用于图像处理的Pillow
- 异步操作:用于高效文件处理和网络请求的aiohttp和aiofiles
该项目作为将多个AI服务集成到一个统一应用程序中的示例,使用户能够探索AI生成内容的创意可能性。
代码库详细说明:
- 前端 (HTML/JavaScript):
- 该应用程序使用单个 HTML 页面 (index.html),采用响应式设计,使用 DaisyUI 和 Tailwind CSS。
- 页面包含音频录制、转录、提示生成、图像生成和故事生成的部分。
- JavaScript 文件 (script.js) 处理用户交互并与后端 API 通信。
2. 后端 (FastAPI) :
- 主要应用程序定义在 app/main.py 中。
- 它使用 FastAPI 创建一个具有多个端点的 Web 服务器:
— a. /: 提供主要 HTML 页面。
— b. /transcribe: 将音频转录为文本。
— c. /generate_prompt: 使用 Groq 的 LLM 从文本生成图像提示。
— d. /generate_image: 使用 Replicate 的 Flux 模型生成图像。
— e. /download_image: 下载并保存生成的图像。
— f. /generate_story_from_image: 根据图像使用 Groq 的 LLaVA 模型生成故事。
— g. /download/{filename}: 提供下载的图像文件。
3. 主要功能:
- 音频录制和转录
- 文本到图像的提示生成
- 根据提示生成图像
- 根据图像生成故事
- 图像下载和保存
4. 外部 API:
- Groq: 用于文本生成(调整后的提示和 故事)
- Replicate AI: black-forest-labs/flux-1.1-pro 模型用于图像生成
- 需要安装的必要包:
fastapi
uvicorn
jinja2
python-multipart
pydantic
python-dotenv
groq
replicate
SpeechRecognition
pydub
aiohttp
aiofiles
Pillow
您可以使用 pip 安装这些包:
pip install fastapi uvicorn jinja2 python-multipart pydantic python-dotenv groq replicate SpeechRecognition pydub aiohttp aiofiles Pillow
执行说明:
- 设置环境变量: 在根目录创建一个 .env 文件,内容如下:
GROQ_API_KEY=your_groq_api_key_here
REPLICATE_API_TOKEN=your_replicate_api_token_here
用您的实际 API 密钥替换占位符值。
- 确保您已准备好所有必要的文件:
- — app/main.py
- — app/config.py
- — app/utils.py
- — templates/index.html
- — static/css/styles.css
- — static/js/script.js
- 运行 FastAPI 服务器: 导航到包含 app/main.py 的目录并运行:
uvicorn app.main:app - reload
- 访问应用程序:
- — 打开浏览器并转到 http://127.0.0.1:8000
- 使用应用程序:
- — a. 点击“开始录音”并说出您的提示。
- — b. 完成后点击“停止录音”。
- — c. 音频将自动转录。
- — d. 点击“生成图像提示”以创建详细提示。
- — e. 点击“生成图像”以根据提示创建图像。
- — f. 使用“下载图像”按钮保存生成的图像。
- — g. 点击“生成故事”以根据生成的图像创建故事。
注意:确保您有良好的互联网连接,因为该应用程序依赖外部 API 提供各种功能。
该应用程序展示了各种 AI 技术的复杂集成,包括语音识别、语言模型和图像生成,所有这些都封装在用户友好的 Web 界面中。
如下面所示的 FastAPI UI
AI图像生成应用
说出您的提示
- 开始录音
- 停止录音
- 转录(转录文本):一位美丽的印度模特在时装秀中走下Runway Ram
- 根据转录文本创建一个新的提示以生成图像
- — 生成的提示:“生成一幅高度逼真的图像,描绘一位令人惊艳的印度模特在标志性的Runway Ram上走动,作为高端时装秀的一部分。这位模特是一位22岁的印度女性,拥有长长的黑发、深棕色的眼睛和无瑕的肌肤,身穿精美、复杂刺绣的lehenga choli,配以金银亮片,传统印度服饰,并搭配高跟鞋。她的服装设计精美,刺绣细致。强调优雅的褶皱、闪亮的面料,以及她优雅的姿态和自信的步伐。配饰方面,加入华丽的珠宝,如手链、金手镯和项链,装饰在她的手、脖子和一侧的发型上。灯光效果起着重要作用,设置温暖的舞台聚光灯,突出模特的服装,并用柔和的蓝色调照亮整个环境。摄影角度应全面展示服装的细节。期望的场景视角是模特正面全身照,走道周围被强烈的金色光线照亮,光线从内部散发出来。”
生成的图像
从图像生成的故事
这位身着金银亮片服装和长裙的选美皇后的惊艳景象足以让任何观众陶醉。Sashaya Gnanavel在前景中显得格外引人注目,自信地走在T台上,吸引着现场观众的目光。她时尚的服装,搭配优雅的珍珠项链,吸引了在场每一个人的注意。这个系列展示了鲜艳的色彩和闪亮的刺绣,增加了活动的整体视觉吸引力。Sashaya在聚光灯下的自信与美丽,真实地证明了她在时尚行业的才华和奉献。她的妆容、珠宝和精美服装所创造的耀眼效果,为设计和工艺的非凡展示奠定了舞台。这个引人注目的场景 encapsulates 魔力和奢华,让观众对这一切的绝美感到惊叹。
代码实现
创建虚拟环境
要使用 Python 的 venv 模块创建虚拟环境,请按照以下步骤操作:
- 打开终端或命令提示符。
- 导航到您的项目目录(您想要创建虚拟环境的地方)。您可以使用 cd 命令更改目录。例如:
cd path/to/your/project
- 通过运行以下命令创建虚拟环境:
python -m venv venv
- 此命令将在您的项目文件夹中创建一个名为 venv 的新目录,该目录将包含虚拟环境(在 Windows 上)
- 激活虚拟环境:
venv\Scripts\activate
- 文件夹结构
- utils.py
import base64
import os
from pydub import AudioSegment
def save_audio(audio_data):
# Decode the base64 audio data
audio_bytes = base64.b64decode(audio_data.split(",")[1])
# Save the audio to a temporary file
temp_file = "temp_audio.webm"
with open(temp_file, "wb") as f:
f.write(audio_bytes)
# Convert WebM to WAV
audio = AudioSegment.from_file(temp_file, format="webm")
wav_file = "temp_audio.wav"
audio.export(wav_file, format="wav")
# Remove the temporary WebM file
os.remove(temp_file)
return wav_file
def text_to_speech(text):
# Implement text-to-speech functionality if needed
pass
- main.py
"""
1. 通过麦克风录制音频
2. 将音频转录为文本
3. 使用 Groq Llama3 模型生成图像提示
4. 使用 Replicate.ai Flux 模型生成图像
5. 显示生成的图像
6. 下载生成的图像
应用程序使用 DaisyUI 和 Tailwind CSS 进行样式设置,提供黑暗模式界面。布局是响应式的,应该在桌面和移动设备上都能良好运行。
注意:您可能需要根据您使用的特定 API 和模型以及部署环境的安全性考虑来调整代码的某些部分。
"""
from fastapi import FastAPI, Request, HTTPException
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
import speech_recognition as sr
from groq import Groq
import replicate
import os
import aiohttp
import aiofiles
import time
from dotenv import load_dotenv
load_dotenv()
from .utils import text_to_speech, save_audio
from PIL import Image
import io
import base64
import base64
## Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
## Initialize Groq client with the API key
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
raise ValueError("GROQ_API_KEY is not set in the environment variables")
groq_client = Groq(api_key=GROQ_API_KEY)
class AudioData(BaseModel):
audio_data: str
class ImagePrompt(BaseModel):
prompt: str
class PromptRequest(BaseModel):
text: str
## Add this new model
class FreeImagePrompt(BaseModel):
prompt: str
image_path: str
@app.get("/")
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/transcribe")
async def transcribe_audio(audio_data: AudioData):
try:
# Save the audio data to a file
audio_file = save_audio(audio_data.audio_data)
# Transcribe the audio
recognizer = sr.Recognizer()
with sr.AudioFile(audio_file) as source:
audio = recognizer.record(source)
text = recognizer.recognize_google(audio)
return JSONResponse(content={"text": text})
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/generate_prompt")
async def generate_prompt(prompt_request: PromptRequest):
try:
text = prompt_request.text
# Use Groq to generate a new prompt
response = groq_client.chat.completions.create(
messages=[
{"role": "system", "content": "You are a creative assistant that generates prompts for realistic image generation."},
{"role": "user", "content": f"Generate a detailed prompt for a realistic image based on this description: {text}.The prompt should be clear and detailed in no more than 200 words."}
],
model="llama-3.1-70b-versatile",
max_tokens=256
)
generated_prompt = response.choices[0].message.content
print(f"tweaked prompt:{generated_prompt}")
return JSONResponse(content={"prompt": generated_prompt})
except Exception as e:
print(f"Error generating prompt: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@app.post("/generate_image")
async def generate_image(image_prompt: ImagePrompt):
try:
prompt = image_prompt.prompt
print(f"Received prompt: {prompt}")
# Use Replicate to generate an image
output = replicate.run(
"black-forest-labs/flux-1.1-pro",
input={
"prompt": prompt,
"aspect_ratio": "1:1",
"output_format": "jpg",
"output_quality": 80,
"safety_tolerance": 2,
"prompt_upsampling": True
}
)
print(f"Raw output: {output}")
print(f"Output type: {type(output)}")
# Convert the FileOutput object to a string
image_url = str(output)
print(f"Generated image URL: {image_url}")
return JSONResponse(content={"image_url": image_url})
except Exception as e:
print(f"Error generating image: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@app.get("/download_image")
async def download_image(image_url: str):
try:
# Create Output folder if it doesn't exist
output_folder = "Output"
os.makedirs(output_folder, exist_ok=True)
# Generate a unique filename
filename = f"generated_image_{int(time.time())}.jpg"
filepath = os.path.join(output_folder, filename)
# Download the image
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as resp:
if resp.status == 200:
async with aiofiles.open(filepath, mode='wb') as f:
await f.write(await resp.read())
# Return the filepath and filename
return JSONResponse(content={
"filepath": filepath,
"filename": filename
})
except Exception as e:
print(f"Error downloading image: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
class StoryRequest(BaseModel):
filepath: str
filename: str
@app.post("/generate_story_from_image")
async def generate_story_from_image(content: StoryRequest):
try:
image_path = content.filepath
print(f"Image path: {image_path}")
# Check if the file exists
if not os.path.exists(image_path):
raise HTTPException(status_code=400, detail="Image file not found")
# Getting the base64 string
base64_image = encode_image(image_path)
client = Groq()
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Generate a clear,concise,meaningful and engaging cover story for a highly acclaimed leisure magazine based on the image provided. The story should keep the audience glued and engaged and the story should bewithin 200 words."},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
},
},
],
}
],
model="llava-v1.5-7b-4096-preview",
)
story = chat_completion.choices[0].message.content
print(f"Generated story: {story}")
return JSONResponse(content={"story": story})
except Exception as e:
print(f"Error generating story from the image: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@app.get("/download/{filename}")
async def serve_file(filename: str):
file_path = os.path.join("Output", filename)
return FileResponse(file_path, filename=filename)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
- script.js
let mediaRecorder;
let audioChunks = [];
const startRecordingButton = document.getElementById('startRecording');
const stopRecordingButton = document.getElementById('stopRecording');
const recordingStatus = document.getElementById('recordingStatus');
const transcription = document.getElementById('transcription');
const generatePromptButton = document.getElementById('generatePrompt');
const generatedPrompt = document.getElementById('generatedPrompt');
const generateImageButton = document.getElementById('generateImage');
const generatedImage = document.getElementById('generatedImage');
const downloadLink = document.getElementById('downloadLink');
const generateStoryButton = document.getElementById('generateStory');
const generatedStory = document.getElementById('generatedStory');
startRecordingButton.addEventListener('click', startRecording);
stopRecordingButton.addEventListener('click', stopRecording);
generatePromptButton.addEventListener('click', generatePrompt);
generateImageButton.addEventListener('click', generateImage);
generateStoryButton.addEventListener('click', generateStory);
async function startRecording() {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = (event) => {
audioChunks.push(event.data);
};
mediaRecorder.onstop = sendAudioToServer;
mediaRecorder.start();
startRecordingButton.disabled = true;
stopRecordingButton.disabled = false;
recordingStatus.textContent = 'Recording...';
}
function stopRecording() {
mediaRecorder.stop();
startRecordingButton.disabled = false;
stopRecordingButton.disabled = true;
recordingStatus.textContent = 'Recording stopped.';
}
async function sendAudioToServer() {
const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
const reader = new FileReader();
reader.readAsDataURL(audioBlob);
reader.onloadend = async () => {
const base64Audio = reader.result;
const response = await fetch('/transcribe', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ audio_data: base64Audio }),
});
const data = await response.json();
transcription.textContent = `Transcription: ${data.text}`;
generatePromptButton.disabled = false;
};
audioChunks = [];
}
async function generatePrompt() {
const text = transcription.textContent.replace('Transcription: ', '');
const response = await fetch('/generate_prompt', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ text: text }),
});
const data = await response.json();
generatedPrompt.textContent = `Generated Prompt: ${data.prompt}`;
generateImageButton.disabled = false;
}
async function generateImage() {
const prompt = generatedPrompt.textContent.replace('Generated Prompt: ', '');
const response = await fetch('/generate_image', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ prompt: prompt }),
});
const data = await response.json();
generatedImage.src = data.image_url;
// Download the image and get the filepath
const downloadResponse = await fetch(`/download_image?image_url=${encodeURIComponent(data.image_url)}`);
const downloadData = await downloadResponse.json();
// Store the filepath and filename for later use
generatedImage.dataset.filepath = downloadData.filepath;
generatedImage.dataset.filename = downloadData.filename;
// Set up the download link
downloadLink.href = `/download/${downloadData.filename}`;
downloadLink.download = downloadData.filename;
downloadLink.style.display = 'inline-block';
}
async function generateStory() {
const imagePath = generatedImage.dataset.filepath;
const filename = generatedImage.dataset.filename;
if (!imagePath || !filename) {
generatedStory.textContent = "Error: Please generate an image first.";
return;
}
try {
const response = await fetch('/generate_story_from_image', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ filepath: imagePath, filename: filename }),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data = await response.json();
// Display the generated story
generatedStory.textContent = data.story;
// Make sure the story container is visible
document.getElementById('storyContainer').style.display = 'block';
} catch (error) {
console.error('Error:', error);
generatedStory.textContent = `Error: ${error.message}`;
}
}
// Modify the download link click event
downloadLink.addEventListener('click', async (event) => {
event.preventDefault();
const response = await fetch(downloadLink.href);
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.style.display = 'none';
a.href = url;
a.download = response.headers.get('Content-Disposition').split('filename=')[1];
document.body.appendChild(a);
a.click();
window.URL.revokeObjectURL(url);
});
- style.css
body {
background-color: #1a1a2e;
color: #ffffff;
}
.container {
max-width: 1200px;
}
#imageContainer {
min-height: 300px;
display: flex;
align-items: center;
justify-content: center;
background-color: #16213e;
border-radius: 8px;
}
#generatedImage {
max-width: 100%;
max-height: 400px;
object-fit: contain;
}
- index.html
<!DOCTYPE html>
<html lang="en" data-theme="dark">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI 图像生成器</title>
<link href="https://cdn.jsdelivr.net/npm/daisyui@3.7.3/dist/full.css" rel="stylesheet" type="text/css" />
<script src="https://cdn.tailwindcss.com"></script>
<link rel="stylesheet" href="{{ url_for('static', path='/css/styles.css') }}">
</head>
<body>
<div class="container mx-auto px-4 py-8">
<h1 class="text-4xl font-bold mb-8 text-center">AI 图像生成器</h1>
<div class="grid grid-cols-1 md:grid-cols-2 gap-8">
<div class="card bg-base-200 shadow-xl">
<div class="card-body">
<h2 class="card-title mb-4">说出你的提示</h2>
<button id="startRecording" class="btn btn-primary mb-4">开始录音</button>
<button id="stopRecording" class="btn btn-secondary mb-4" disabled>停止录音</button>
<div id="recordingStatus" class="text-lg mb-4"></div>
<div id="transcription" class="text-lg mb-4"></div>
<button id="generatePrompt" class="btn btn-accent mb-4" disabled>生成图像提示</button>
<div id="generatedPrompt" class="text-lg mb-4"></div>
<button id="generateImage" class="btn btn-success" disabled>生成图像</button>
</div>
</div>
<div class="card bg-base-200 shadow-xl">
<div class="card-body">
<h2 class="card-title mb-4">生成的图像</h2>
<div id="imageContainer" class="mb-4">
<img id="generatedImage" src="" alt="生成的图像" class="w-full h-auto">
</div>
<a id="downloadLink" href="#" download="generated_image.png" class="btn btn-info" style="display: none;">下载图像</a>
</div>
</div>
</div>
<!-- Add this new section after the existing cards -->
<div class="card bg-base-200 shadow-xl mt-8">
<div class="card-body">
<h2 class="card-title mb-4">从图像生成故事</h2>
<button id="generateStory" class="btn btn-primary mb-4">生成故事</button>
<div id="storyContainer" class="mb-4">
<p id="generatedStory" class="text-lg"></p>
</div>
</div>
</div>
</div>
<script src="{{ url_for('static', path='/js/script.js') }}"></script>
</body>
</html>
结论
AI图像生成器和故事创作项目成功整合了各种AI技术,创建了一个互动式Web应用程序,允许用户根据音频提示生成图像和故事。通过利用FastAPI作为后端和现代前端技术,该应用程序提供了无缝的用户体验。
关键要点:
- AI模型的集成:该项目展示了如何集成多个AI模型,包括用于文本生成的Groq和用于图像生成的Replicate,以创建一个增强用户创意的统一应用程序。
- 用户互动:该应用程序允许用户通过语音命令进行互动,使其易于访问且用户友好。录音、转录以及基于该输入生成内容的能力展示了语音驱动应用程序的潜力。
- 动态内容生成:通过根据用户输入动态生成图像和故事,该应用程序突显了AI在内容创作中的能力,为用户提供独特且个性化的输出。
- 响应式设计:使用DaisyUI和Tailwind CSS确保应用程序在视觉上吸引人且响应迅速,适应各种设备上的用户。
- 未来增强:该项目可以通过加入额外功能进一步增强,例如用户身份验证、保存用户生成的内容,以及扩展用于不同创意任务的AI模型范围。
总体而言,该项目作为一个综合示例,展示了如何构建一个结合音频处理、图像生成和故事讲述的AI驱动的网络应用程序,为创意领域的创新应用铺平了道路。
参考文献
- FastAPI 文档: FastAPI 是一个现代的 web 框架,用于使用 Python 构建 API。它旨在易于使用且快速。
- Pydantic: Pydantic 用于数据验证和使用 Python 类型注释的设置管理。
- Groq: Groq 是一个构建和部署 AI 模型的平台。它提供文本生成和其他 AI 任务的 API。
- Replicate: Replicate 是一个允许您在云中运行机器学习模型的平台。它提供各种模型的 API,包括图像生成。
- SpeechRecognition: SpeechRecognition 是一个执行语音识别的库,支持多种引擎和 API。
- Pillow: Pillow 是 Python Imaging Library (PIL) 的一个分支,为您的 Python 代码添加图像处理能力。
- JavaScript Fetch API: Fetch API 提供了一种现代方式在 JavaScript 中发起网络请求。
- HTML5 音频 API: HTML5 Audio API 允许您在 web 应用程序中播放音频文件。
- DaisyUI: DaisyUI 是一个为 Tailwind CSS 提供预设计组件的组件库。
- Tailwind CSS: Tailwind CSS 是一个实用优先的 CSS 框架,用于创建自定义设计,而无需离开 HTML。