Generate spatial relationship dataset
In [ ]:
Copied!
import os
import jsonlines
import wandb
import weave
import os
import jsonlines
import wandb
import weave
In [ ]:
Copied!
# Start the W&B run and the Weave client, both under the same project name.
# Each is initialised exactly once; repeating the calls is redundant.
wandb.init(project="2d-spatial-relationship")
weave.init(project_name="2d-spatial-relationship")
In [ ]:
Copied!
# Object class names (80 entries) used as the vocabulary from which the
# entity pairs of every prompt are drawn. Defined once; order is preserved
# because it determines the order of the generated rows.
mscoco_classes = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]
In [ ]:
Copied!
def compose_prompt(entity_1: str, entity_2: str, relationship: str) -> str:
    """Build a prompt of the form "<article> <entity_1> <relationship> <article> <entity_2>".

    The indefinite article is chosen by a simple spelling heuristic: "an" when
    the entity name starts with one of the letters a/e/i/o/u (in either case),
    otherwise "a". Pronunciation is not considered.

    Args:
        entity_1: Name of the first object.
        entity_2: Name of the second object.
        relationship: Spatial phrase placed between the two objects.

    Returns:
        The composed prompt string.
    """
    vowels = ("a", "e", "i", "o", "u")
    # startswith() on the lowered name is case-insensitive and, unlike
    # indexing with [0], does not raise on an empty string (which gets "a").
    article_1 = "an" if entity_1.lower().startswith(vowels) else "a"
    article_2 = "an" if entity_2.lower().startswith(vowels) else "a"
    return f"{article_1} {entity_1} {relationship} {article_2} {entity_2}"
# One dict per generated prompt, with keys prompt/entity_1/entity_2/relationship.
spatial_relationship_rows = []
relationships = [
    "near",
    "next to",
    "on side of",
    "side of",
    "on the right of",
    "on the left of",
    "on the bottom of",
    "on the top of",
]
# The same rows, mirrored into a W&B table with matching column order.
table = wandb.Table(columns=["prompt", "entity_1", "entity_2", "relationship"])

# Every ordered pair of distinct classes is combined with every relationship;
# (a, b) and (b, a) are both generated, a class is never paired with itself.
# The rows and the table are built in a single pass.
for entity_1 in mscoco_classes:
    for entity_2 in mscoco_classes:
        if entity_1 == entity_2:
            continue
        for relationship in relationships:
            row = {
                "prompt": compose_prompt(entity_1, entity_2, relationship),
                "entity_1": entity_1,
                "entity_2": entity_2,
                "relationship": relationship,
            }
            spatial_relationship_rows.append(row)
            table.add_data(
                row["prompt"],
                row["entity_1"],
                row["entity_2"],
                row["relationship"],
            )
In [ ]:
Copied!
# Dump the rows as JSON Lines. write_all() emits one JSON object per line;
# write() would serialise the entire list as a single JSON array on one line.
with jsonlines.open("dataset.jsonl", mode="w") as writer:
    writer.write_all(spatial_relationship_rows)

dataset_description = """A dataset of prompts for evaluation 2D spatial relationships between objects in images.
The dataset is generated using the vocabulary of objects from the [MSCOCO](https://cocodataset.org) dataset.
The idea for this dataset is inspired by the
[T2I-Compbench 2D Spatial Relationships dataset](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/huangky_connect_hku_hk/Er_BhrcMwGREht6gnKGIErMBx8H8yRXLDfWgWQwKaObQ4w?e=YzT5wG).
"""

# Attach the JSONL file to a dataset artifact and log it, together with the
# table of the same rows, to the active W&B run (once each).
artifact = wandb.Artifact(
    name="2d-spatial-prompts-mscoco",
    type="dataset",
    metadata={
        "format": "jsonl",
        "description": dataset_description,
    },
)
artifact.add_file(local_path="dataset.jsonl")
wandb.log_artifact(artifact)
wandb.log({"dataset/2d_spatial_prompts": table})
In [ ]:
Copied!
# Wrap the generated rows in a Weave dataset and publish it once.
spatial_relationship_dataset = weave.Dataset(
    name="2d-spatial-prompts-mscoco",
    rows=spatial_relationship_rows,
    description=dataset_description,
)
weave.publish(spatial_relationship_dataset)
In [ ]:
Copied!
# Close the W&B run, then delete the local JSONL file that was written only
# so it could be attached to the artifact.
wandb.finish()
try:
    os.remove("dataset.jsonl")
except FileNotFoundError:
    # Already removed (e.g. this cell was run before); nothing left to clean.
    pass