Build a Photo Restoration App Using Replicate, Next.js and Upstash
Before continuing, this post assumes that you have the following:
- An Upstash account where you have a Redis instance created
- A Replicate account with access to your API token
- A Next.js project to implement our desired functionality
- A Vercel account to deploy your project to
What is this?
Have you been wanting to get started using machine learning to generate images from the models available on Replicate? Well, in this tutorial, we'll explore Replicate's wide range of hosted models and Upstash's Redis. Not only will we explore these models, but we'll walk through the process of setting one up, and touch on how you can easily update the implementation to use other models too.
In this tutorial, we'll cover the usage of Microsoft's Bringing Old Photos Back to Life model, which essentially takes an old photo, runs it through the model, and outputs an edited, and hopefully improved version of your photo.
What's the architecture?
If you have some React experience, you should be able to determine how the app architecture works simply by reading through the codebase, but to make it that little bit easier, or if you simply prefer to see an overview, there is one provided below.
What do I need to start?
To start with, you'll of course need a Next.js project. You can create one by following the Next.js setup guide here, or use an existing project if you already have one set up. In this tutorial, we're also using Tailwind CSS, but you can use any styling setup that you prefer.
Now that we have a basic Next.js project set up, we can install Upstash's Redis library by running the command:
npm install @upstash/redis
Next up, we'll want to populate our .env.local file with the following keys. The Redis REST URL and token can be found in your Upstash console, the Replicate API token can be found here under your account, and the site URL is wherever you deploy the app, which in this case is the Vercel deployment URL.
SITE_URL=https://your-project-url.vercel.app
REPLICATE_API_TOKEN=
UPSTASH_REDIS_REST_URL=
UPSTASH_REDIS_REST_TOKEN=
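Since a missing key only shows up later as a failed request, it can help to check the environment once at startup. The following is a minimal sketch, not part of the original project: `readConfig` and `AppConfig` are hypothetical names, and the environment is passed in as a plain object so the function stays easy to test.

```typescript
// The four keys that this tutorial's .env.local file defines.
const REQUIRED_KEYS = [
  "SITE_URL",
  "REPLICATE_API_TOKEN",
  "UPSTASH_REDIS_REST_URL",
  "UPSTASH_REDIS_REST_TOKEN",
] as const;

type RequiredKey = (typeof REQUIRED_KEYS)[number];
type AppConfig = Record<RequiredKey, string>;

// Hypothetical helper (an assumption, not in the original code): returns the
// validated config, or throws an error naming every missing or empty key.
function readConfig(env: Record<string, string | undefined>): AppConfig {
  const missing = REQUIRED_KEYS.filter((key) => !env[key]);
  if (missing.length > 0) {
    throw new Error(`Missing environment variables: ${missing.join(", ")}`);
  }
  const config = {} as AppConfig;
  for (const key of REQUIRED_KEYS) {
    config[key] = env[key] as string;
  }
  return config;
}
```

In a Next.js API route, this could be called as `readConfig(process.env)` before any request is sent to Replicate or Upstash.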
Setting up the frontend form
To begin with, we'll need a component that handles the form submission, the polling, and the display of completed images.
Restore Image Form Creation
File: pages/index.tsx
import { MouseEvent, RefObject, useRef, useState } from "react";
import Head from "next/head";
import useInterval from "../hooks/useInterval";
export default function Home() {
const [restoring, setRestoring] = useState<boolean>(false);
const [messageId, setMessageId] = useState<string | null>(null);
const [prediction, setPrediction] = useState<any>({});
const [outputImageUrl, setOutputImageUrl] = useState<string | null>(null);
const imageUrlRef: RefObject<HTMLInputElement> = useRef(null);
const hrRef: RefObject<HTMLInputElement> = useRef(null);
const scratchRef: RefObject<HTMLInputElement> = useRef(null);
useInterval(
async () => {
await fetch(`/api/poll?id=${messageId}`)
.then((res: any) => res.json())
.then((data: any) => {
if (!data.output) {
return;
}
setRestoring(false);
setMessageId(null);
setOutputImageUrl(data.output);
})
.catch((err: any) => console.error(err));
},
messageId ? 1000 : null,
);
async function restoreImage(e: any) {
e.preventDefault();
setRestoring(true);
await fetch("/api/create", {
method: "POST",
body: JSON.stringify({
image_url: imageUrlRef.current?.value,
// Read the checked state; a checkbox's value is "on" whether or not it is ticked.
is_hr: hrRef.current?.checked,
has_scratches: scratchRef.current?.checked,
}),
headers: { "Content-Type": "application/json" },
})
.then((res: Response) => res.json())
.then((data: any) => {
setMessageId(data.data.id);
setPrediction(data.data);
})
.catch((err: Error) => console.error(err));
}
async function cancel(e: MouseEvent<HTMLButtonElement>) {
e.preventDefault();
await fetch("/api/cancel", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ cancel_url: prediction.urls.cancel }),
})
.then((res: Response) => res.json())
.then((data: any) => {
setMessageId(null);
setPrediction({});
setRestoring(false);
})
.catch((err: Error) => console.error(err));
}
return (
<>
<Head>
<title>PhotoRescue</title>
<meta
name="description"
content="A simple Next.js application that utilizes Replicate to restore old photos."
/>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<link rel="icon" href="/favicon.ico" />
</Head>
<main>
<div className="my-16 flex flex-col items-center justify-center md:my-32">
<h1 className="text-5xl font-black">PhotoRescue</h1>
<p className="mt-4">Restore your old photos to their former glory.</p>
{outputImageUrl && (
<div className="flex flex-col items-center justify-center">
<img
src={outputImageUrl}
alt="Restored Image"
className="mt-8 h-auto w-72"
/>
<button
type="button"
onClick={() => setOutputImageUrl(null)}
className="mt-8 inline-flex items-center rounded-full border border-transparent bg-gray-900 px-6 py-2.5 text-sm font-medium text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-gray-600 focus:ring-offset-2 disabled:opacity-50"
>
Start Again
</button>
</div>
)}
{!outputImageUrl && (
<form
onSubmit={restoreImage}
className="mt-10 flex w-full max-w-lg flex-col items-center"
>
<div className="w-full space-y-4">
<div>
<label htmlFor="image_url" className="text-sm font-semibold">
Image URL
</label>
<input
name="image_url"
id="image_url"
type="text"
defaultValue="https://replicate.delivery/mgxm/b033ff07-1d2e-4768-a137-6c16b5ed4bed/d_1.png"
placeholder="https://example.com/image.png"
className="mt-0.5 block w-full rounded-md border border-gray-300 p-2 shadow-sm focus:border-gray-500 focus:ring-gray-500"
ref={imageUrlRef}
required
/>
</div>
<div className="max-w-lg space-y-4">
<div className="relative flex items-start">
<div className="flex h-5 items-center">
<input
name="is_hr"
id="is_hr"
type="checkbox"
className="h-4 w-4 rounded border-gray-300 text-gray-900 focus:ring-gray-500"
ref={hrRef}
/>
</div>
<div className="ml-3 text-sm">
<label
htmlFor="is_hr"
className="font-medium text-gray-900"
>
Is High Resolution?
</label>
<p className="text-gray-500">
Check this if the input image is a high resolution
photo.
</p>
</div>
</div>
<div className="relative flex items-start">
<div className="flex h-5 items-center">
<input
name="is_scratched"
id="is_scratched"
type="checkbox"
className="h-4 w-4 rounded border-gray-300 text-gray-900 focus:ring-gray-500"
ref={scratchRef}
defaultChecked={true}
/>
</div>
<div className="ml-3 text-sm">
<label
htmlFor="is_scratched"
className="font-medium text-gray-900"
>
Has Scratches?
</label>
<p className="text-gray-500">
Check this if the input image has visible scratches over
it.
</p>
</div>
</div>
</div>
</div>
<div className="mt-6 flex gap-2">
<button
type="submit"
disabled={restoring}
className="inline-flex items-center rounded-full border border-transparent bg-gray-900 px-6 py-2.5 text-sm font-medium text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-gray-600 focus:ring-offset-2 disabled:opacity-50"
>
{restoring ? "Restoring..." : "Restore"}
</button>
{restoring && prediction && (
<button
type="button"
onClick={cancel}
className="inline-flex items-center rounded-full border border-gray-900 bg-white px-6 py-2.5 text-sm font-medium text-gray-900 shadow-sm hover:bg-gray-100 focus:outline-none focus:ring-2 focus:ring-gray-600 focus:ring-offset-2"
>
Cancel
</button>
)}
</div>
</form>
)}
</div>
</main>
</>
);
}
By default, this component displays a form that lets the user enter the URL of the image they'd like to restore, along with a couple of options: whether the image is high resolution, and whether it has scratches that need to be removed. Once the user fills in this information and submits the form, the component sends a POST request to /api/create along with the form data.
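To make the shape of that request body concrete, here is a small sketch. `buildCreateBody` is a hypothetical helper, not part of the original component, and the URL check is an assumption about what counts as valid input. It takes booleans for the two options because a checkbox reports its state through `checked`, while its `value` is the string "on" regardless of whether it is ticked.

```typescript
// Shape of the JSON body that the form posts to /api/create.
interface CreateRequestBody {
  image_url: string;
  is_hr: boolean;
  has_scratches: boolean;
}

// Hypothetical helper: builds the body from plain values. In the component,
// the two booleans would come from each checkbox's `checked` property.
function buildCreateBody(
  imageUrl: string,
  isHr: boolean,
  hasScratches: boolean,
): CreateRequestBody {
  const trimmed = imageUrl.trim();
  // Assumption: only http(s) URLs are accepted as input images.
  if (!/^https?:\/\//i.test(trimmed)) {
    throw new Error("image_url must be an http(s) URL");
  }
  return { image_url: trimmed, is_hr: isHr, has_scratches: hasScratches };
}
```

The result can be passed straight to `JSON.stringify` as the `body` of the fetch call.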
Once the API responds with the prediction information, the component enters a polling state in which it sends a GET request to /api/poll once per second to check whether the prediction has completed. Once a polling response includes an output, which indicates that Replicate has sent a request to our callback endpoint, we have access to the prediction output.
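The polling behaviour can be summarised as a small state transition. The sketch below is illustrative only: `PollState`, `applyPollResponse` and `pollDelay` are hypothetical names that mirror the `useState` values and the `useInterval` delay in the component, pulled out as pure functions so the logic is easy to follow.

```typescript
// Client-side state that the polling loop cares about.
interface PollState {
  restoring: boolean;
  messageId: string | null;
  outputImageUrl: string | null;
}

// The part of the /api/poll response that the form reads.
interface PollResponse {
  output?: string | null;
}

// While the stored prediction has no output, nothing changes. Once it has
// one, polling stops (messageId becomes null) and the URL is kept for display.
function applyPollResponse(state: PollState, response: PollResponse): PollState {
  if (!response.output) {
    return state;
  }
  return { restoring: false, messageId: null, outputImageUrl: response.output };
}

// Delay handed to useInterval: once per second while an id is set, else paused.
function pollDelay(state: PollState): number | null {
  return state.messageId ? 1000 : null;
}
```

Returning `null` as the delay is what pauses the interval, which is why clearing `messageId` is enough to stop the polling.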
Whilst the polling is ongoing, the form displays a button with the option to cancel the prediction. Once pressed, this sends a POST request to /api/cancel with the cancel_url from the prediction data that we received upon initial creation.
The polling implementation uses a custom hook located in hooks/useInterval.ts, which works with React's component lifecycle and provides a more convenient way to handle intervals with callbacks within any given React component. You can read more about this hook here and here should you want to learn about it in greater detail.
import { useEffect, useRef } from "react";
function useInterval(callback: () => void, delay: number | null) {
const savedCallback = useRef(callback);
useEffect(() => {
savedCallback.current = callback;
}, [callback]);
useEffect(() => {
if (!delay && delay !== 0) {
return;
}
const id = setInterval(() => savedCallback.current(), delay);
return () => clearInterval(id);
}, [delay]);
}
export default useInterval;
API Setup
The API setup, made up of a few files, is what allows us to create and cancel predictions, poll to check when predictions are complete, as well as specify the callback that Replicate will use for when the prediction is complete on their end.
Image Prediction Creation
File: pages/api/create.ts
import type { NextApiRequest, NextApiResponse } from "next";
import fetch, { Response } from "node-fetch";
import redis from "../../lib/redis";
export default async function handler(
req: NextApiRequest,
res: NextApiResponse,
) {
if (req.method !== "POST") {
return res.status(400).json({
message: `Invalid request method: ${req.method}.`,
});
}
const { image_url, is_hr, has_scratches }: any = req.body;
await fetch("https://api.replicate.com/v1/predictions", {
method: "POST",
headers: {
Authorization: `Token ${process.env.REPLICATE_API_TOKEN}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
version:
"c75db81db6cbd809d93cc3b7e7a088a351a3349c9fa02b6d393e35e0d51ba799",
input: {
image: image_url,
HR: is_hr,
with_scratch: has_scratches,
},
webhook_completed: `${process.env.SITE_URL}/api/callback`,
}),
})
.then((res: Response) => res.json())
.then(async (data: any) => {
await redis.set(data.id, data);
return res.status(202).json({ data: data });
})
.catch((error: Error) => {
return res.status(500).json({ message: error.message });
});
}
For our create API endpoint, we first check that the incoming request is a POST request, and if not, we return a simple 400 response. We then send a POST request to Replicate, authenticated with our Replicate API token. The request body includes the version parameter, which indicates which model we are sending the request to (this is found under the "API" tab on the model you'd like to use), the model's input parameters populated with the data from the form on the frontend, and the webhook URL that Replicate should call once the prediction completes.
Once the request has been sent, we store the returned prediction in Redis under its id, and return the prediction data to the frontend, which uses the id to poll the Redis item until it contains a completed prediction.
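As a sketch of how the form fields map onto the request body sent to Replicate, the mapping can be pulled out into a pure function. `buildPredictionRequest` is a hypothetical helper rather than part of the original endpoint, and the trailing-slash handling is an added assumption. The version hash and field names are the ones used in the create endpoint.

```typescript
// Model version used in this tutorial, copied from the create endpoint.
const MODEL_VERSION =
  "c75db81db6cbd809d93cc3b7e7a088a351a3349c9fa02b6d393e35e0d51ba799";

interface PredictionRequest {
  version: string;
  input: { image: string; HR: boolean; with_scratch: boolean };
  webhook_completed: string;
}

// Hypothetical helper: maps the form fields onto the model's input names and
// points the webhook at our own /api/callback route.
function buildPredictionRequest(
  form: { image_url: string; is_hr: boolean; has_scratches: boolean },
  siteUrl: string,
): PredictionRequest {
  // Drop trailing slashes so the callback URL never contains "//api".
  const base = siteUrl.replace(/\/+$/, "");
  return {
    version: MODEL_VERSION,
    input: {
      image: form.image_url,
      HR: form.is_hr,
      with_scratch: form.has_scratches,
    },
    webhook_completed: `${base}/api/callback`,
  };
}
```

Keeping this mapping in one place also makes it clearer what would need to change when pointing the endpoint at a different model.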
Callback
File: pages/api/callback.ts
import type { NextApiRequest, NextApiResponse } from "next";
import redis from "../../lib/redis";
export default async function handler(
req: NextApiRequest,
res: NextApiResponse,
) {
const { body }: any = req;
try {
await redis.set(body.id, body);
return res.status(200).send(body);
} catch (error) {
return res.status(500).json({ error });
}
}
The callback endpoint is what Replicate will send a POST request to in order to let us know that the processing of a given prediction has finished. When we receive this request, we retrieve the prediction data from the request body, and update the given Redis item with the completed prediction data.
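Because the callback is a public endpoint, it may be worth checking that the body at least carries a usable id before writing it to Redis. The following type guard is a minimal sketch under that assumption. `isPredictionPayload` and `PredictionPayload` are hypothetical names, and the check covers only the one field the handler relies on.

```typescript
// Minimal shape we rely on from the webhook body: an id used as the Redis key.
interface PredictionPayload {
  id: string;
  output?: unknown;
  [key: string]: unknown;
}

// Hypothetical guard (not in the original handler): true only for objects
// that carry a non-empty string id.
function isPredictionPayload(body: unknown): body is PredictionPayload {
  if (typeof body !== "object" || body === null) {
    return false;
  }
  const id = (body as { id?: unknown }).id;
  return typeof id === "string" && id.length > 0;
}
```

In the handler, a body that fails this check could be answered with a 400 response instead of being stored.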
Polling
File: pages/api/poll.ts
import type { NextApiRequest, NextApiResponse } from "next";
import redis from "../../lib/redis";
export default async function handler(
req: NextApiRequest,
res: NextApiResponse,
) {
const { id }: any = req.query;
try {
const data = await redis.get(id);
if (!data) {
return res
.status(404)
.json({ message: "Data for supplied ID not found" });
}
return res.status(200).json(data);
} catch (error: any) {
return res.status(500).json({ message: error.message });
}
}
For our polling setup, we extract the id from the request query and try to retrieve the data stored in Redis under that identifier. If no data is found, we return a 404 response; otherwise, we return the data as part of a 200 response.
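To see how the create, callback and poll endpoints fit together, here is a sketch that swaps Redis for an in-memory `Map`. Everything in it is a stand-in for illustration: `savePrediction` and `poll` are hypothetical names, the `Prediction` type only lists the fields used here, and a real deployment would keep using the Redis client.

```typescript
// Only the fields this sketch needs; real predictions carry more.
type Prediction = { id: string; status: string; output: string | null };

type PollResult =
  | { status: 404; message: string }
  | { status: 200; data: Prediction };

// In-memory stand-in for the Redis client, for illustration only.
const store = new Map<string, Prediction>();

// Both /api/create and /api/callback write the prediction under its id, so
// the callback simply overwrites the pending entry with the completed one.
function savePrediction(prediction: Prediction): void {
  store.set(prediction.id, prediction);
}

// Mirrors /api/poll: 404 when nothing is stored, otherwise the stored data.
function poll(id: string): PollResult {
  const data = store.get(id);
  if (!data) {
    return { status: 404, message: "Data for supplied ID not found" };
  }
  return { status: 200, data };
}
```

The frontend keeps polling while `output` is null, so the first 200 response that carries an output is the one written by the callback.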
Cancel
File: pages/api/cancel.tsx
import type { NextApiRequest, NextApiResponse } from "next";
import fetch, { Response } from "node-fetch";
export default async function handler(
req: NextApiRequest,
res: NextApiResponse,
) {
if (req.method !== "POST") {
return res.status(400).json({
message: `Invalid request method: ${req.method}.`,
});
}
const { cancel_url }: any = req.body;
await fetch(cancel_url, {
method: "POST",
headers: {
Authorization: `Token ${process.env.REPLICATE_API_TOKEN}`,
"Content-Type": "application/json",
},
})
.then((res: Response) => res.json())
.then((data: any) => {
return res.status(202).json({ data: data });
})
.catch((error: Error) => {
return res.status(500).json({ message: error.message });
});
}
The API endpoint for cancelling predictions that have been started is rather straightforward. We extract the cancel_url passed through from the frontend, which itself comes from the prediction returned when the creation request was submitted, and send a POST request to that URL, alongside our Replicate API token.
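One thing to keep in mind is that the cancel_url arrives from the browser, and the endpoint attaches our API token to whatever URL it is given. A cautious variant would check the URL first. The guard below is a sketch under an assumption: it expects cancel URLs of the form `https://api.replicate.com/v1/predictions/<id>/cancel`, based on the API host used in the create endpoint, and `isReplicateCancelUrl` is a hypothetical name.

```typescript
// Assumed cancel URL format: HTTPS, Replicate's API host, and a path of
// /v1/predictions/<id>/cancel. Adjust the pattern if the format differs.
const CANCEL_URL_PATTERN =
  /^https:\/\/api\.replicate\.com\/v1\/predictions\/[A-Za-z0-9_-]+\/cancel$/;

// Hypothetical guard: returns true only for strings matching the pattern, so
// the API token is never sent to an arbitrary host.
function isReplicateCancelUrl(candidate: unknown): candidate is string {
  return typeof candidate === "string" && CANCEL_URL_PATTERN.test(candidate);
}
```

In the handler, a cancel_url that fails this check could be rejected with a 400 response before any request is made.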
Libs
For our libs, we'll create a Redis client, which is used for tracking the state of each prediction.
File: lib/redis.ts
import { Redis } from "@upstash/redis";
const redis = new Redis({
url: process.env.UPSTASH_REDIS_REST_URL as string,
token: process.env.UPSTASH_REDIS_REST_TOKEN as string,
});
export default redis;
This object will be used within the application to store and retrieve data whilst being polled, so that we know when the webhook completion from Replicate has happened.
Conclusion
Replicate have a variety of models available that can be used via an API. With Vercel and Upstash, it's easier than ever to utilise machine learning models and deploy usable web applications.
If you'd like to view the complete repository, you can access it here.
Further Development
This is just a simple example of using a rather simple model with Replicate. By switching up the form parameters and the version in the API, you can easily change to another model. As long as you have your Replicate API token linked, you'll be able to consume any of the available models.
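One way to make that switch less manual is to describe each model as a small config object. The sketch below is only an illustration: `ModelConfig` and `toRequestBody` are hypothetical names, and the second entry is a made-up placeholder whose version string and input names would have to be replaced with the values shown under the model's "API" tab.

```typescript
// How one hosted model is called: its version hash, and how form values map
// onto the model's own input names.
interface ModelConfig<F> {
  version: string;
  toInput: (form: F) => Record<string, unknown>;
}

interface RestoreForm {
  image_url: string;
  is_hr: boolean;
  has_scratches: boolean;
}

// The photo restoration model used throughout this tutorial.
const restoreModel: ModelConfig<RestoreForm> = {
  version: "c75db81db6cbd809d93cc3b7e7a088a351a3349c9fa02b6d393e35e0d51ba799",
  toInput: (form) => ({
    image: form.image_url,
    HR: form.is_hr,
    with_scratch: form.has_scratches,
  }),
};

// Placeholder entry: neither the version nor the input name is real.
const placeholderModel: ModelConfig<{ prompt: string }> = {
  version: "<version-from-the-API-tab>",
  toInput: (form) => ({ prompt: form.prompt }),
};

// Builds the body for POST /v1/predictions from any model config.
function toRequestBody<F>(model: ModelConfig<F>, form: F, webhookUrl: string) {
  return {
    version: model.version,
    input: model.toInput(form),
    webhook_completed: webhookUrl,
  };
}
```

With this in place, the create endpoint would only need to pick a config, while the callback and polling endpoints stay the same.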
You can explore all of Replicate's available models here, and once you find one you'd like to experiment with, you can click on the "API" tab in order to view the usage of it. Here you'll also find buttons for Python, cURL, Cog and Docker, which allow you to test out the model, but it's also useful for knowing which parameters are required, and how they are sent.