Skip to content

Spatial Relationship Metrics

This module implements the spatial relationship metric described in Section 4.2 of T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation.

Using an object-detection model for spatial relationship evaluation as proposed in T2I-CompBench
Weave gives us a holistic view of the evaluations, letting us drill into individual outputs and scores.

Example

Step 1: Generate evaluation dataset

Generate an evaluation dataset using the MSCOCO object vocabulary and publish it as a Weave Dataset. You can follow this notebook to learn about the process.

Step 2: Evaluate

import asyncio
import weave

from hemm.models import DiffusersModel
from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge

# Checkpoint for the DETR judge (these match the judge's own defaults)
detr_model_address = "facebook/detr-resnet-50"
detr_revision = "no_timm"

# Initialize Weave
weave.init(project_name="image-quality-leaderboard")

# Initialize the diffusion model to be evaluated as a `weave.Model`
model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")

# Define the judge model for 2d spatial relationship metric
judge = DETRSpatialRelationShipJudge(
    model_address=detr_model_address, revision=detr_revision
)

# Add 2d spatial relationship Metric to the evaluation pipeline
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")

# Evaluate!
dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get()
evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
summary = asyncio.run(evaluation.evaluate(model))

Metrics

SpatialRelationshipMetric2D

Bases: Scorer

Spatial relationship metric for image generation as proposed in Section 4.2 from the paper T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation.

Sample usage

import asyncio
import weave

from hemm.models import DiffusersModel
from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge

# Checkpoint for the DETR judge (these match the judge's own defaults)
detr_model_address = "facebook/detr-resnet-50"
detr_revision = "no_timm"

# Initialize Weave
weave.init(project_name="image-quality-leaderboard")

# Initialize the diffusion model to be evaluated as a `weave.Model`
model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")

# Define the judge model for 2d spatial relationship metric
judge = DETRSpatialRelationShipJudge(
    model_address=detr_model_address, revision=detr_revision
)

# Define 2d spatial relationship Metric to the evaluation pipeline
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")

# Evaluate!
dataset = weave.ref("2d-spatial-t2i_compbench_spatial_prompts-mscoco:v0").get()
evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
asyncio.run(evaluation.evaluate(model))

Parameters:

Name Type Description Default
judge Union[Model, DETRSpatialRelationShipJudge]

The judge model to predict the bounding boxes from the generated image.

required
iou_threshold Optional[float]

The IoU threshold for the spatial relationship.

0.1
distance_threshold Optional[float]

The distance threshold for the spatial relationship.

150
name Optional[str]

The name of the metric.

required
Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
class SpatialRelationshipMetric2D(weave.Scorer):
    """Spatial relationship metric for image generation as proposed in Section 4.2 from the paper
    [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).

    A judge detects bounding boxes in the generated image. The relative position of
    the boxes for `entity_1` and `entity_2` is then scored against the prompted
    relationship, read as "`entity_1` <relationship> `entity_2`".

    !!! example "Sample usage"
        ```python
        import asyncio
        import weave

        from hemm.models import DiffusersModel
        from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
        from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge

        # Initialize Weave
        weave.init(project_name="image-quality-leaderboard")

        # Initialize the diffusion model to be evaluated as a `weave.Model`
        model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")

        # Define the judge model for 2d spatial relationship metric
        judge = DETRSpatialRelationShipJudge(
            model_address="facebook/detr-resnet-50", revision="no_timm"
        )

        # Define 2d spatial relationship Metric to the evaluation pipeline
        metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")

        # Evaluate!
        dataset = weave.ref("2d-spatial-t2i_compbench_spatial_prompts-mscoco:v0").get()
        evaluation = weave.Evaluation(dataset=dataset, scorers=[metric])
        asyncio.run(evaluation.evaluate(model))
        ```

    Args:
        judge (Union[weave.Model, DETRSpatialRelationShipJudge]): The judge model to predict
            the bounding boxes from the generated image.
        iou_threshold (float, optional): The IoU threshold for the spatial relationship.
            Defaults to 0.1.
        distance_threshold (float, optional): The distance threshold (in pixels) for the
            spatial relationship. Defaults to 150.
        name (Optional[str], optional): The name of the metric.
    """

    judge: weave.Model
    iou_threshold: float = 0.1
    distance_threshold: float = 150

    @weave.op()
    def compose_judgement(
        self,
        prompt: str,
        image: Image.Image,
        entity_1: str,
        entity_2: str,
        relationship: str,
        boxes: List[BoundingBox],
    ) -> Dict[str, Any]:
        """Compose the judgement based on the response and the predicted bounding boxes.

        Offsets are measured from the center of `entity_1`'s box to the center of
        `entity_2`'s box. They are in image pixel coordinates, with the origin at the
        top-left, x growing to the right and y growing downward. This follows the
        T2I-CompBench convention, so e.g. "on the left of" means "`entity_1` is on the
        left of `entity_2`".

        Args:
            prompt (str): The prompt using which the image was generated. Not used by
                the scoring logic.
            image (Image.Image): The input image.
            entity_1 (str): First entity (the subject of the relationship).
            entity_2 (str): Second entity (the reference object).
            relationship (str): Relationship between the entities.
            boxes (List[BoundingBox]): The predicted bounding boxes.

        Returns:
            Dict[str, Any]: The comprehensive spatial relationship judgement, with keys
                `entity_1_present`, `entity_2_present`, `score` (between 0 and 1 for
                non-negative thresholds) and `judge_annotated_image` (the input image
                annotated with every predicted box).
        """
        _ = prompt

        # Determine presence of entities in the judgement. If a label is detected
        # more than once, the last matching box is the one that gets scored.
        judgement = {
            "entity_1_present": False,
            "entity_2_present": False,
        }
        entity_1_box: BoundingBox = None
        entity_2_box: BoundingBox = None
        annotated_image = image
        for box in boxes:
            if box.label == entity_1:
                judgement["entity_1_present"] = True
                entity_1_box = box
            elif box.label == entity_2:
                judgement["entity_2_present"] = True
                entity_2_box = box
            annotated_image = annotate_with_bounding_box(annotated_image, box)

        judgement["score"] = 0.0
        if not (judgement["entity_1_present"] and judgement["entity_2_present"]):
            return {**judgement, "judge_annotated_image": annotated_image}

        # Signed offsets (entity_2 center minus entity_1 center). The sign encodes the
        # direction and must not be discarded: x > 0 means entity_2 lies to the right
        # of entity_1, y > 0 means entity_2 lies below entity_1.
        center_distance_x = (
            entity_2_box.box_coordinates_center.x
            - entity_1_box.box_coordinates_center.x
        )
        center_distance_y = (
            entity_2_box.box_coordinates_center.y
            - entity_1_box.box_coordinates_center.y
        )
        iou = get_iou(entity_1_box, entity_2_box)

        def directional_score(along: float, across: float) -> float:
            """Score a directional relation whose sign check has already passed.

            The offset along the relation's axis must dominate the perpendicular
            offset. Boxes overlapping at or above the IoU threshold are penalized
            proportionally.
            """
            if abs(along) <= abs(across):
                return 0.0
            if iou < self.iou_threshold:
                return 1.0
            # Guard against 0/0 when iou_threshold is configured as 0.
            return self.iou_threshold / iou if iou > 0 else 1.0

        score = 0.0
        if relationship in ["near", "next to", "on side of", "side of"]:
            if (
                abs(center_distance_x) < self.distance_threshold
                or abs(center_distance_y) < self.distance_threshold
            ):
                score = 1.0
            else:
                farthest = max(abs(center_distance_x), abs(center_distance_y))
                # Guard against 0/0 when distance_threshold is configured as <= 0.
                score = self.distance_threshold / farthest if farthest > 0 else 1.0
        elif relationship == "on the right of":
            # entity_1 right of entity_2 -> entity_2 lies to the left (x < 0).
            if center_distance_x < 0:
                score = directional_score(center_distance_x, center_distance_y)
        elif relationship == "on the left of":
            # entity_1 left of entity_2 -> entity_2 lies to the right (x > 0).
            if center_distance_x > 0:
                score = directional_score(center_distance_x, center_distance_y)
        elif relationship == "on the bottom of":
            # entity_1 below entity_2 -> entity_2 lies above (y < 0).
            if center_distance_y < 0:
                score = directional_score(center_distance_y, center_distance_x)
        elif relationship == "on the top of":
            # entity_1 above entity_2 -> entity_2 lies below (y > 0).
            if center_distance_y > 0:
                score = directional_score(center_distance_y, center_distance_x)
        judgement["score"] = score

        return {**judgement, "judge_annotated_image": annotated_image}

    @weave.op()
    def score(
        self,
        prompt: str,
        entity_1: str,
        entity_2: str,
        relationship: str,
        model_output: Dict[str, Any],
    ) -> Dict[str, Union[bool, float, int]]:
        """Calculate the spatial relationship score for the given prompt and model output.

        Args:
            prompt (str): The prompt for the model.
            entity_1 (str): The first entity in the spatial relationship.
            entity_2 (str): The second entity in the spatial relationship.
            relationship (str): The spatial relationship between the two entities.
            model_output (Dict[str, Any]): The output from the model; its `"image"`
                entry is the generated image that gets judged.

        Returns:
            Dict[str, Union[bool, float, int]]: A single-entry mapping from this
                metric's `name` to the spatial relationship score.
        """
        image = model_output["image"]
        boxes: List[BoundingBox] = self.judge.predict(image)
        judgement = self.compose_judgement(
            prompt, image, entity_1, entity_2, relationship, boxes
        )
        return {self.name: judgement["score"]}

compose_judgement(prompt, image, entity_1, entity_2, relationship, boxes)

Compose the judgement based on the response and the predicted bounding boxes.

Parameters:

Name Type Description Default
prompt str

The prompt using which the image was generated.

required
image Image

The input image.

required
entity_1 str

First entity.

required
entity_2 str

Second entity.

required
relationship str

Relationship between the entities.

required
boxes List[BoundingBox]

The predicted bounding boxes.

required

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: The comprehensive spatial relationship judgement.

Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@weave.op()
def compose_judgement(
    self,
    prompt: str,
    image: Image.Image,
    entity_1: str,
    entity_2: str,
    relationship: str,
    boxes: List[BoundingBox],
) -> Dict[str, Any]:
    """Compose the judgement based on the response and the predicted bounding boxes.

    Offsets are measured from the center of `entity_1`'s box to the center of
    `entity_2`'s box. They are in image pixel coordinates, with the origin at the
    top-left, x growing to the right and y growing downward. This follows the
    T2I-CompBench convention, so e.g. "on the left of" means "`entity_1` is on the
    left of `entity_2`".

    Args:
        prompt (str): The prompt using which the image was generated. Not used by
            the scoring logic.
        image (Image.Image): The input image.
        entity_1 (str): First entity (the subject of the relationship).
        entity_2 (str): Second entity (the reference object).
        relationship (str): Relationship between the entities.
        boxes (List[BoundingBox]): The predicted bounding boxes.

    Returns:
        Dict[str, Any]: The comprehensive spatial relationship judgement, with keys
            `entity_1_present`, `entity_2_present`, `score` (between 0 and 1 for
            non-negative thresholds) and `judge_annotated_image` (the input image
            annotated with every predicted box).
    """
    _ = prompt

    # Determine presence of entities in the judgement. If a label is detected
    # more than once, the last matching box is the one that gets scored.
    judgement = {
        "entity_1_present": False,
        "entity_2_present": False,
    }
    entity_1_box: BoundingBox = None
    entity_2_box: BoundingBox = None
    annotated_image = image
    for box in boxes:
        if box.label == entity_1:
            judgement["entity_1_present"] = True
            entity_1_box = box
        elif box.label == entity_2:
            judgement["entity_2_present"] = True
            entity_2_box = box
        annotated_image = annotate_with_bounding_box(annotated_image, box)

    judgement["score"] = 0.0
    if not (judgement["entity_1_present"] and judgement["entity_2_present"]):
        return {**judgement, "judge_annotated_image": annotated_image}

    # Signed offsets (entity_2 center minus entity_1 center). The sign encodes the
    # direction and must not be discarded: x > 0 means entity_2 lies to the right
    # of entity_1, y > 0 means entity_2 lies below entity_1.
    center_distance_x = (
        entity_2_box.box_coordinates_center.x
        - entity_1_box.box_coordinates_center.x
    )
    center_distance_y = (
        entity_2_box.box_coordinates_center.y
        - entity_1_box.box_coordinates_center.y
    )
    iou = get_iou(entity_1_box, entity_2_box)

    def directional_score(along: float, across: float) -> float:
        """Score a directional relation whose sign check has already passed.

        The offset along the relation's axis must dominate the perpendicular
        offset. Boxes overlapping at or above the IoU threshold are penalized
        proportionally.
        """
        if abs(along) <= abs(across):
            return 0.0
        if iou < self.iou_threshold:
            return 1.0
        # Guard against 0/0 when iou_threshold is configured as 0.
        return self.iou_threshold / iou if iou > 0 else 1.0

    score = 0.0
    if relationship in ["near", "next to", "on side of", "side of"]:
        if (
            abs(center_distance_x) < self.distance_threshold
            or abs(center_distance_y) < self.distance_threshold
        ):
            score = 1.0
        else:
            farthest = max(abs(center_distance_x), abs(center_distance_y))
            # Guard against 0/0 when distance_threshold is configured as <= 0.
            score = self.distance_threshold / farthest if farthest > 0 else 1.0
    elif relationship == "on the right of":
        # entity_1 right of entity_2 -> entity_2 lies to the left (x < 0).
        if center_distance_x < 0:
            score = directional_score(center_distance_x, center_distance_y)
    elif relationship == "on the left of":
        # entity_1 left of entity_2 -> entity_2 lies to the right (x > 0).
        if center_distance_x > 0:
            score = directional_score(center_distance_x, center_distance_y)
    elif relationship == "on the bottom of":
        # entity_1 below entity_2 -> entity_2 lies above (y < 0).
        if center_distance_y < 0:
            score = directional_score(center_distance_y, center_distance_x)
    elif relationship == "on the top of":
        # entity_1 above entity_2 -> entity_2 lies below (y > 0).
        if center_distance_y > 0:
            score = directional_score(center_distance_y, center_distance_x)
    judgement["score"] = score

    return {**judgement, "judge_annotated_image": annotated_image}

score(prompt, entity_1, entity_2, relationship, model_output)

Calculate the spatial relationship score for the given prompt and model output.

Parameters:

Name Type Description Default
prompt str

The prompt for the model.

required
entity_1 str

The first entity in the spatial relationship.

required
entity_2 str

The second entity in the spatial relationship.

required
relationship str

The spatial relationship between the two entities.

required
model_output Dict[str, Any]

The output from the model.

required

Returns:

Type Description
Dict[str, Union[bool, float, int]]

Dict[str, Union[bool, float, int]]: The comprehensive spatial relationship judgement.

Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@weave.op()
def score(
    self,
    prompt: str,
    entity_1: str,
    entity_2: str,
    relationship: str,
    model_output: Dict[str, Any],
) -> Dict[str, Union[bool, float, int]]:
    """Score how well a generated image realizes the prompted spatial relationship.

    The judge detects boxes in the generated image, and `compose_judgement` turns
    those boxes into a relationship score.

    Args:
        prompt (str): The prompt for the model.
        entity_1 (str): The first entity in the spatial relationship.
        entity_2 (str): The second entity in the spatial relationship.
        relationship (str): The spatial relationship between the two entities.
        model_output (Dict[str, Any]): The output from the model; its `"image"`
            entry holds the generated image.

    Returns:
        Dict[str, Union[bool, float, int]]: A single-entry mapping from this
            metric's `name` to the spatial relationship score.
    """
    generated_image = model_output["image"]
    detections: List[BoundingBox] = self.judge.predict(generated_image)
    verdict = self.compose_judgement(
        prompt, generated_image, entity_1, entity_2, relationship, detections
    )
    metric_value = verdict["score"]
    return {self.name: metric_value}

Judges

DETRSpatialRelationShipJudge

Bases: Model

DETR spatial relationship judge model for 2D images.

Parameters:

Name Type Description Default
model_address str

The address of the model to use.

'facebook/detr-resnet-50'
revision str

The revision of the model to use.

'no_timm'
name str

The name of the judge model

'detr_spatial_relationship_judge'
Source code in hemm/metrics/spatial_relationship/judges/detr.py
class DETRSpatialRelationShipJudge(weave.Model):
    """[DETR](https://huggingface.co/docs/transformers/en/model_doc/detr) spatial relationship judge model for 2D images.

    The image processor and object-detection model are loaded once, at
    construction time. They are then used to turn an image into labelled
    bounding boxes.

    Args:
        model_address (str, optional): The address of the model to use.
            Defaults to `"facebook/detr-resnet-50"`.
        revision (str, optional): The revision of the model to use.
            Defaults to `"no_timm"`.
        name (str, optional): The name of the judge model.
            Defaults to `"detr_spatial_relationship_judge"`.
    """

    model_address: str
    revision: str
    name: str
    _feature_extractor: DetrImageProcessor = None
    _object_detection_model: DetrForObjectDetection = None

    def __init__(
        self,
        model_address: str = "facebook/detr-resnet-50",
        revision: str = "no_timm",
        name: str = "detr_spatial_relationship_judge",
    ):
        super().__init__(model_address=model_address, revision=revision, name=name)
        self._feature_extractor = DetrImageProcessor.from_pretrained(
            self.model_address, revision=self.revision
        )
        self._object_detection_model = DetrForObjectDetection.from_pretrained(
            self.model_address, revision=self.revision
        )

    @weave.op()
    def predict(
        self, image: Image.Image, threshold: float = 0.9
    ) -> List[BoundingBox]:
        """Predict the bounding boxes from the input image.

        Args:
            image (Image.Image): The input image.
            threshold (float, optional): Minimum detection confidence for a box to be
                kept. Defaults to 0.9, the value that was previously hard-coded.

        Returns:
            List[BoundingBox]: The predicted bounding boxes, in pixel coordinates of
                `image`, each carrying its class label and confidence score.
        """
        encoding = self._feature_extractor(image, return_tensors="pt")
        # Pure inference: disable autograd so no computation graph is built or kept.
        with torch.no_grad():
            outputs = self._object_detection_model(**encoding)
        # PIL's `size` is (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = self._feature_extractor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]
        id2label = self._object_detection_model.config.id2label
        bboxes = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            bboxes.append(
                BoundingBox(
                    box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                    box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                    box_coordinates_center=CartesianCoordinate2D(
                        x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                    ),
                    label=id2label[label.item()],
                    score=score.item(),
                )
            )
        return bboxes

predict(image)

Predict the bounding boxes from the input image.

Parameters:

Name Type Description Default
image Image

The input image.

required

Returns:

Type Description
List[BoundingBox]

List[BoundingBox]: The predicted bounding boxes.

Source code in hemm/metrics/spatial_relationship/judges/detr.py
@weave.op()
def predict(self, image: Image.Image, threshold: float = 0.9) -> List[BoundingBox]:
    """Predict the bounding boxes from the input image.

    Args:
        image (Image.Image): The input image.
        threshold (float, optional): Minimum detection confidence for a box to be
            kept. Defaults to 0.9, the value that was previously hard-coded.

    Returns:
        List[BoundingBox]: The predicted bounding boxes, in pixel coordinates of
            `image`, each carrying its class label and confidence score.
    """
    encoding = self._feature_extractor(image, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        outputs = self._object_detection_model(**encoding)
    # PIL's `size` is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = self._feature_extractor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    id2label = self._object_detection_model.config.id2label
    bboxes = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        xmin, ymin, xmax, ymax = box.tolist()
        bboxes.append(
            BoundingBox(
                box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                box_coordinates_center=CartesianCoordinate2D(
                    x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                ),
                label=id2label[label.item()],
                score=score.item(),
            )
        )
    return bboxes

RTDETRSpatialRelationShipJudge

Bases: Model

RT-DETR spatial relationship judge model for 2D images.

Parameters:

Name Type Description Default
model_address str

The address of the model to use.

'facebook/detr-resnet-50'
revision str

The revision of the model to use.

required
name str

The name of the judge model

'detr_spatial_relationship_judge'
Source code in hemm/metrics/spatial_relationship/judges/rt_detr.py
class RTDETRSpatialRelationShipJudge(weave.Model):
    """[RT-DETR](https://huggingface.co/docs/transformers/en/model_doc/rt_detr) spatial relationship judge model for 2D images.

    The RT-DETR image processor and object-detection model are loaded once, at
    construction time. They are then used to turn an image into labelled
    bounding boxes.

    Args:
        model_address (str, optional): The address of the model to use.
            Defaults to `"facebook/detr-resnet-50"`.
        name (str, optional): The name of the judge model.
            Defaults to `"detr_spatial_relationship_judge"`.
    """

    model_address: str
    name: str
    _feature_extractor: RTDetrImageProcessor = None
    _object_detection_model: RTDetrForObjectDetection = None

    def __init__(
        self,
        # NOTE(review): this default is a DETR checkpoint, not an RT-DETR one; loading
        # it through the RT-DETR classes is likely to mis-initialize the model.
        # Consider passing an RT-DETR checkpoint explicitly. TODO confirm and fix the
        # default.
        model_address: str = "facebook/detr-resnet-50",
        name: str = "detr_spatial_relationship_judge",
    ):
        super().__init__(model_address=model_address, name=name)
        self._feature_extractor = RTDetrImageProcessor.from_pretrained(
            self.model_address
        )
        self._object_detection_model = RTDetrForObjectDetection.from_pretrained(
            self.model_address
        )

    @weave.op()
    def predict(
        self, image: Image.Image, threshold: float = 0.9
    ) -> List[BoundingBox]:
        """Predict the bounding boxes from the input image.

        Args:
            image (Image.Image): The input image.
            threshold (float, optional): Minimum detection confidence for a box to be
                kept. Defaults to 0.9, the value that was previously hard-coded.

        Returns:
            List[BoundingBox]: The predicted bounding boxes, in pixel coordinates of
                `image`, each carrying its class label and confidence score.
        """
        encoding = self._feature_extractor(image, return_tensors="pt")
        # Pure inference: disable autograd so no computation graph is built or kept.
        with torch.no_grad():
            outputs = self._object_detection_model(**encoding)
        # PIL's `size` is (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = self._feature_extractor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]
        id2label = self._object_detection_model.config.id2label
        bboxes = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            bboxes.append(
                BoundingBox(
                    box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                    box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                    box_coordinates_center=CartesianCoordinate2D(
                        x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                    ),
                    label=id2label[label.item()],
                    score=score.item(),
                )
            )
        return bboxes

predict(image)

Predict the bounding boxes from the input image.

Parameters:

Name Type Description Default
image Image

The input image.

required

Returns:

Type Description
List[BoundingBox]

List[BoundingBox]: The predicted bounding boxes.

Source code in hemm/metrics/spatial_relationship/judges/rt_detr.py
@weave.op()
def predict(self, image: Image.Image, threshold: float = 0.9) -> List[BoundingBox]:
    """Predict the bounding boxes from the input image.

    Args:
        image (Image.Image): The input image.
        threshold (float, optional): Minimum detection confidence for a box to be
            kept. Defaults to 0.9, the value that was previously hard-coded.

    Returns:
        List[BoundingBox]: The predicted bounding boxes, in pixel coordinates of
            `image`, each carrying its class label and confidence score.
    """
    encoding = self._feature_extractor(image, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        outputs = self._object_detection_model(**encoding)
    # PIL's `size` is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = self._feature_extractor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    id2label = self._object_detection_model.config.id2label
    bboxes = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        xmin, ymin, xmax, ymax = box.tolist()
        bboxes.append(
            BoundingBox(
                box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                box_coordinates_center=CartesianCoordinate2D(
                    x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                ),
                label=id2label[label.item()],
                score=score.item(),
            )
        )
    return bboxes