Working With ML

We'll take a look at a few examples that show how Gradio interfaces can be built to work with your ML models.

Image Classification in TensorFlow / Keras Colab link

We'll start with an ImageNet image classifier, which we'll load using TensorFlow. Note that although this example is described in terms of Inception Net, the code below loads MobileNetV2 from tf.keras.applications. Since this is an image classification model, we will use the Image input interface. We'll output a dictionary of labels and their corresponding confidence scores with the Label output interface. (The original Inception Net architecture can be found here)

import gradio as gr
import tensorflow as tf
import numpy as np
import json
from os.path import dirname, realpath, join

# Load human-readable labels for ImageNet.
current_dir = dirname(realpath(__file__))
with open(join(current_dir, "files/imagenet_labels.json")) as labels_file:
    labels = json.load(labels_file)

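# Load the pretrained ImageNet classifier once, at import time.
mobile_net = tf.keras.applications.MobileNetV2()

def image_classifier(im):
    # Add a batch dimension and scale pixels the way MobileNetV2 expects.
    arr = np.expand_dims(im, axis=0)
    arr = tf.keras.applications.mobilenet_v2.preprocess_input(arr)
    prediction = mobile_net.predict(arr).flatten()
    # Map each of the 1000 ImageNet labels to its confidence score.
    return {labels[i]: float(prediction[i]) for i in range(1000)}

# The source cuts this call off after the Image input. The function, the
# Label output, capture_session and interpretation are restored from the
# surrounding text; num_top_classes=3 is an assumed value.
iface = gr.Interface(
    image_classifier,
    gr.inputs.Image(shape=(224, 224)),
    gr.outputs.Label(num_top_classes=3),
    capture_session=True,
    interpretation="default")
iface.launch()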

This code will produce the interface below. The interface gives you a way to test the classifier by dragging and dropping images, and also lets you modify the input image using the image editing tools that appear when you click the edit button. Notice that here we provided actual gradio.inputs and gradio.outputs objects to the Interface function instead of using string shortcuts. This lets us use the built-in preprocessing (e.g. image resizing) and postprocessing (e.g. choosing the number of labels to display) provided by these interfaces. Finally, we use capture_session=True to ensure compatibility with TF 1.x.
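To make the postprocessing step more concrete, here is a minimal sketch of what "choosing the number of labels to display" amounts to: given the label-to-confidence dictionary returned by the function, keep only the k highest-scoring entries. The helper name `top_labels` and the sample scores are hypothetical, and this is an illustration in plain Python rather than Gradio's own implementation.

```python
def top_labels(confidences, k=3):
    """Return the k (label, score) pairs with the highest scores, best first."""
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


# Example scores in the same {label: confidence} shape the classifier returns.
scores = {"cheetah": 0.81, "leopard": 0.12, "lion": 0.04, "tabby": 0.03}
best = top_labels(scores, k=2)
```

Because the function returns scores for all 1000 classes, trimming is left to the output component, which keeps the model code free of display concerns.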

Try it out on your device or run it in a Colab notebook!

Add Interpretation

The above code also shows how you can add interpretation to your interface. You can use our out-of-the-box functions for text and image interpretation or supply your own interpretation functions. To use the out-of-the-box functions, just specify "default" for the interpretation parameter. (Note: this only works for text or image inputs and label outputs.)

gr.Interface(image_classifier, gr.inputs.Image(shape=(224, 224)), gr.outputs.Label(), capture_session=True, interpretation="default").launch()
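For readers wondering what an interpretation function computes, one common approach for text inputs is leave-one-out: remove each word in turn and record how much the model's score changes. The sketch below illustrates that general idea only; it makes no claim about how Gradio's "default" interpretation is implemented, and `leave_one_out` and `toy_sentiment` are hypothetical names.

```python
def leave_one_out(text, score_fn):
    """Score each word by how much the model's score drops when it is removed."""
    words = text.split()
    baseline = score_fn(" ".join(words))
    importances = []
    for i, word in enumerate(words):
        reduced = " ".join(words[:i] + words[i + 1:])
        importances.append((word, baseline - score_fn(reduced)))
    return importances


def toy_sentiment(text):
    # Hypothetical stand-in for a model: counts occurrences of "great".
    return float(text.split().count("great"))


result = leave_one_out("a great movie", toy_sentiment)
```

Words with a large positive value are the ones the score depends on most; here only "great" matters to the toy model.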

Image Classification in PyTorch Colab link

Let's now wrap a very similar model, ResNet, except this time in PyTorch. We'll also use the Image to Label interface. (The original ResNet architecture can be found here)

import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
from os.path import dirname, realpath, join

# Load human-readable labels for the ImageNet classes that ResNet predicts.
current_dir = dirname(realpath(__file__))
with open(join(current_dir, "files/imagenet_labels.json")) as labels_file:
    labels = json.load(labels_file)

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

def predict(inp):
    # Gradio passes the image as a NumPy array; convert it to a batched tensor.
    inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        # Softmax turns the raw logits into per-class confidences.
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    return {labels[i]: float(prediction[i]) for i in range(1000)}

inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=3)
iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs)
iface.launch()


This code will produce the interface below.
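The predict function above relies on a softmax to turn the network's raw outputs into confidences that the Label output can display. As a rough illustration of that step, independent of PyTorch, the sketch below applies a softmax to a short list of made-up logits and pairs the results with made-up labels; both lists are hypothetical.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = ["cat", "dog", "fish"]   # hypothetical label list
logits = [2.0, 1.0, 0.1]          # hypothetical raw model outputs
confidences = dict(zip(labels, softmax(logits)))
```

The resulting dictionary has the same {label: confidence} shape that predict returns for the 1000 ImageNet classes.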

Text Generation with Transformers (GPT-2) Colab link

Let's wrap a Text to Text interface around GPT-2, a text generation model that works on provided starter text. Click here to learn more about GPT-2 and similar language models.

import gradio as gr

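# Placeholder text-to-text function; a GPT-2 generation function with the
# same string-in, string-out signature would be passed to fn instead.
def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()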

This code will produce the interface below. That's all that's needed!
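The snippet above wires up a simple greeting function rather than GPT-2 itself. As a hedged sketch of how a text generation model could fit the same string-in, string-out shape, the code below wraps any generator that returns output in the style of a Hugging Face `transformers` text-generation pipeline (a list of dicts with a "generated_text" key). The names `make_text_fn` and `fake_generator` are hypothetical, and the fake generator exists only so the sketch runs without downloading a model.

```python
def make_text_fn(generator, max_length=50):
    """Wrap a text-generation callable as a plain str -> str function."""
    def generate(prompt):
        # A transformers text-generation pipeline returns a list of dicts,
        # each with a "generated_text" key; the first candidate is used here.
        outputs = generator(prompt, max_length=max_length)
        return outputs[0]["generated_text"]
    return generate


def fake_generator(prompt, max_length=50):
    # Stand-in so the sketch runs without downloading a model.
    return [{"generated_text": prompt + " ..."}]


generate_text = make_text_fn(fake_generator)
sample = generate_text("Once upon a time")

# With a real model (assumption: the `transformers` package is installed):
#   from transformers import pipeline
#   generate_text = make_text_fn(pipeline("text-generation", model="gpt2"))
#   gr.Interface(fn=generate_text, inputs="text", outputs="text")
```

Keeping the model behind a plain function means the Interface call stays identical whether the function is a greeting or a language model.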

Answering Questions with BERT-QA Colab link

What if our model takes more than one input? Let's wrap a 2-input to 1-output interface around BERT-QA, a model that can answer general questions.

import gradio as gr
import os, sys
file_folder = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(file_folder, "files"))
from bert import QA

model = QA('bert-large-uncased-whole-word-masking-finetuned-squad')
def qa_func(paragraph, question):
    return model.predict(paragraph, question)["answer"]

# The source cuts this call off after the second Textbox. The input list
# brackets and the "text" output are restored from the surrounding text
# (an assumption).
iface = gr.Interface(qa_func,
    [
        gr.inputs.Textbox(lines=7, label="Context", default="Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."),
        gr.inputs.Textbox(lines=1, label="Question", default="When did Victoria enact its constitution?"),
    ],
    "text")
iface.launch()

As shown in the code, Gradio can wrap functions with multiple inputs or outputs: simply pass a list of the components needed. The number of input components should match the number of parameters taken by fn, and the number of output components should match the number of values returned by fn. So if a model returns multiple outputs, you pass in a list of output interfaces in the same way.
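The counting rule above can be checked mechanically before building an interface. The sketch below is one possible way to do that in plain Python, using `inspect.signature` to count parameters and a sample call to count return values. The helper `check_arity` and the stub function are hypothetical and are not part of Gradio.

```python
import inspect


def check_arity(fn, input_components, output_components, sample_args):
    """Compare component counts with fn's parameters and return values."""
    n_params = len(inspect.signature(fn).parameters)
    result = fn(*sample_args)
    # A tuple is treated as several return values, anything else as one.
    n_returns = len(result) if isinstance(result, tuple) else 1
    return n_params == len(input_components), n_returns == len(output_components)


def qa_stub(paragraph, question):
    # Hypothetical stand-in for qa_func: two inputs, one output.
    return paragraph.split()[0]


sample = ("Victoria has a constitution", "Who?")
inputs_ok, outputs_ok = check_arity(qa_stub, ["textbox", "textbox"], ["text"], sample)
mismatch = check_arity(qa_stub, ["textbox"], ["text"], sample)
```

In the mismatched case only one input component is supplied for a two-parameter function, so the first flag comes back False.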

Numerical Interfaces: Titanic Survival Model Colab link

Many models have numeric or categorical inputs, which we support with a variety of interfaces. Let's wrap a multiple-input-to-label interface around a Titanic survival model.

import pandas as pd
import numpy as np
import sklearn
import gradio as gr
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import os

current_dir = os.path.dirname(os.path.realpath(__file__))
data = pd.read_csv(os.path.join(current_dir, 'files/titanic.csv'))

def encode_age(df):
    df.Age = df.Age.fillna(-0.5)
    bins = (-1, 0, 5, 12, 18, 25, 35, 60, 120)
    categories = pd.cut(df.Age, bins, labels=False)
    df.Age = categories
    return df

def encode_fare(df):
    df.Fare = df.Fare.fillna(-0.5)
    bins = (-1, 0, 8, 15, 31, 1000)
    categories = pd.cut(df.Fare, bins, labels=False)
    df.Fare = categories
    return df

def encode_df(df):
    df = encode_age(df)
    df = encode_fare(df)
    sex_mapping = {"male": 0, "female": 1}
    df = df.replace({'Sex': sex_mapping})
    embark_mapping = {"S": 1, "C": 2, "Q": 3}
    df = df.replace({'Embarked': embark_mapping})
    df.Embarked = df.Embarked.fillna(0)
    df["Company"] = 0
    df.loc[(df["SibSp"] > 0), "Company"] = 1
    df.loc[(df["Parch"] > 0), "Company"] = 2
    df.loc[(df["SibSp"] > 0) & (df["Parch"] > 0), "Company"] = 3
    df = df[["PassengerId", "Pclass", "Sex", "Age", "Fare", "Embarked", "Company", "Survived"]]
    return df

train = encode_df(data)

X_all = train.drop(['Survived', 'PassengerId'], axis=1)
y_all = train['Survived']

num_test = 0.20
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)

# Train a random forest on the encoded training split.
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
    # Build a one-row frame with the same columns, in the same order, as X_all.
    df = pd.DataFrame.from_dict({
        'Pclass': [passenger_class + 1],
        'Sex': [0 if is_male else 1],
        'Age': [age],
        'Fare': [fare],
        'Embarked': [embark_point + 1],
        'Company': [(1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)]
    })
    df = encode_age(df)
    df = encode_fare(df)
    pred = clf.predict_proba(df)[0]
    return {'Perishes': float(pred[0]), 'Survives': float(pred[1])}

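# Several lines of this call are missing in the source. The function, the
# "checkbox" and "number" inputs, and the "label" output are restored from
# the signature of predict_survival and the example rows (an assumption).
iface = gr.Interface(
    predict_survival,
    [
        gr.inputs.Dropdown(["first", "second", "third"], type="index"),
        "checkbox",
        gr.inputs.Slider(0, 80),
        gr.inputs.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
        "number",
        gr.inputs.Radio(["S", "C", "Q"], type="index"),
    ],
    "label",
    examples=[
        ["first", True, 30, [], 50, "S"],
        ["second", False, 40, ["Sibling", "Child"], 10, "Q"],
        ["third", True, 30, ["Child"], 20, "S"],
    ])
iface.launch()
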

This code will produce the interface below.
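Two encoding steps in the example are easy to get wrong: the age binning done with pd.cut and the "Company" code built from the checkbox group. The sketch below reproduces both in plain Python so the mapping can be checked by hand. It assumes pd.cut's default right-closed intervals, and the helper names `bin_index` and `encode_company` are hypothetical.

```python
from bisect import bisect_left

# Same bin edges as encode_age in the example above.
AGE_BINS = (-1, 0, 5, 12, 18, 25, 35, 60, 120)


def bin_index(value, bins):
    """Index of the right-closed interval (bins[i], bins[i+1]] containing value."""
    i = bisect_left(bins, value) - 1
    # Values outside the outermost edges fall in no bin.
    return i if 0 <= i < len(bins) - 1 else None


def encode_company(company):
    """0 = alone, 1 = sibling only, 2 = child only, 3 = both."""
    return (1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)


age_bucket = bin_index(30, AGE_BINS)
missing_age_bucket = bin_index(-0.5, AGE_BINS)
company_code = encode_company(["Sibling", "Child"])
```

A missing age is filled with -0.5 in the example, which is why it lands in the first bucket, and the company code matches the 0 to 3 values assigned to the training data in encode_df.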