API Reference

This section provides detailed API reference documentation for the Kura package, automatically generated from the source code using mkdocstrings.

How to Use This Reference

The API reference is organized by module, with each module containing related classes and functions. For each class, you'll find:

  • Constructor parameters and their descriptions
  • Instance methods with parameter details and return types
  • Properties and attributes

To use these classes in your code, import them from their respective modules:

from kura import Kura
from kura.embedding import OpenAIEmbeddingModel
from kura.summarisation import SummaryModel
# And so on...

Core Classes

Procedural API

The procedural API provides a functional approach to conversation analysis with composable pipeline functions.
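
For orientation, here is a minimal sketch that composes the four pipeline functions documented below into a single run. It relies only on the signatures shown on this page; the concrete model instances (summary, clustering, meta-clustering, and dimensionality-reduction models) are passed in by the caller, and their constructors are covered in the per-class sections further down.

import asyncio

from kura import (
    CheckpointManager,
    summarise_conversations,
    generate_base_clusters_from_conversation_summaries,
    reduce_clusters_from_base_clusters,
    reduce_dimensionality_from_clusters,
)


async def run_pipeline(conversations, summary_model, cluster_model, meta_model, dim_model):
    # One checkpoint directory shared by every stage; each model writes its own file.
    checkpoint_mgr = CheckpointManager("./checkpoints")

    summaries = await summarise_conversations(
        conversations, model=summary_model, checkpoint_manager=checkpoint_mgr
    )
    clusters = await generate_base_clusters_from_conversation_summaries(
        summaries, model=cluster_model, checkpoint_manager=checkpoint_mgr
    )
    reduced = await reduce_clusters_from_base_clusters(
        clusters, model=meta_model, checkpoint_manager=checkpoint_mgr
    )
    return await reduce_dimensionality_from_clusters(
        reduced, model=dim_model, checkpoint_manager=checkpoint_mgr
    )

# Example invocation (models constructed as shown in the sections below):
# projected = asyncio.run(run_pipeline(my_conversations, SummaryModel(), ...))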

Pipeline Functions

kura.summarise_conversations(conversations: List[Conversation], *, model: BaseSummaryModel, checkpoint_manager: Optional[CheckpointManager] = None) -> List[ConversationSummary] async

Generate summaries for a list of conversations.

This is a pure function that takes conversations and a summary model, and returns conversation summaries. Optionally uses checkpointing.

The function works with any model that implements BaseSummaryModel, supporting heterogeneous backends (OpenAI, vLLM, Hugging Face, etc.) through polymorphism.
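
Because the function only touches model.checkpoint_filename and model.summarise(...) (see the source below), any object exposing those two members can act as a backend. The following is a hedged sketch of a custom backend; BaseSummaryModel's import path and any additional abstract members are not shown on this page, so treat those details as assumptions.

# Sketch only: BaseSummaryModel's import path is assumed, not documented on this page.
from kura.base_classes import BaseSummaryModel  # assumed location


class MyLocalSummaryModel(BaseSummaryModel):
    @property
    def checkpoint_filename(self) -> str:
        # Used by summarise_conversations() to name the checkpoint file.
        return "my_local_summaries.jsonl"

    async def summarise(self, conversations):
        # Run your own model here and return a list of ConversationSummary objects.
        raise NotImplementedError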

Parameters:

Name               | Type                        | Description                                                | Default
conversations      | List[Conversation]          | List of conversations to summarize                         | required
model              | BaseSummaryModel            | Model to use for summarization (OpenAI, vLLM, local, etc.) | required
checkpoint_manager | Optional[CheckpointManager] | Optional checkpoint manager for caching                    | None

Returns:

Type                      | Description
List[ConversationSummary] | List of conversation summaries

Example

openai_model = OpenAISummaryModel(api_key="sk-...")
checkpoint_mgr = CheckpointManager("./checkpoints")
summaries = await summarise_conversations(
    conversations=my_conversations,
    model=openai_model,
    checkpoint_manager=checkpoint_mgr
)

Source code in kura/v1/kura.py
async def summarise_conversations(
    conversations: List[Conversation],
    *,
    model: BaseSummaryModel,
    checkpoint_manager: Optional[CheckpointManager] = None,
) -> List[ConversationSummary]:
    """Generate summaries for a list of conversations.

    This is a pure function that takes conversations and a summary model,
    and returns conversation summaries. Optionally uses checkpointing.

    The function works with any model that implements BaseSummaryModel,
    supporting heterogeneous backends (OpenAI, vLLM, Hugging Face, etc.)
    through polymorphism.

    Args:
        conversations: List of conversations to summarize
        model: Model to use for summarization (OpenAI, vLLM, local, etc.)
        checkpoint_manager: Optional checkpoint manager for caching

    Returns:
        List of conversation summaries

    Example:
        >>> openai_model = OpenAISummaryModel(api_key="sk-...")
        >>> checkpoint_mgr = CheckpointManager("./checkpoints")
        >>> summaries = await summarise_conversations(
        ...     conversations=my_conversations,
        ...     model=openai_model,
        ...     checkpoint_manager=checkpoint_mgr
        ... )
    """
    logger.info(
        f"Starting summarization of {len(conversations)} conversations using {type(model).__name__}"
    )

    # Try to load from checkpoint
    if checkpoint_manager:
        cached = checkpoint_manager.load_checkpoint(
            model.checkpoint_filename, ConversationSummary
        )
        if cached:
            logger.info(f"Loaded {len(cached)} summaries from checkpoint")
            return cached

    # Generate summaries
    logger.info("Generating new summaries...")
    summaries = await model.summarise(conversations)
    logger.info(f"Generated {len(summaries)} summaries")

    # Save to checkpoint
    if checkpoint_manager:
        logger.info(f"Saving summaries to checkpoint: {model.checkpoint_filename}")
        checkpoint_manager.save_checkpoint(model.checkpoint_filename, summaries)

    return summaries

kura.generate_base_clusters_from_conversation_summaries(summaries: List[ConversationSummary], *, model: BaseClusterModel, checkpoint_manager: Optional[CheckpointManager] = None) -> List[Cluster] async

Generate base clusters from conversation summaries.

This function groups similar summaries into initial clusters using the provided clustering model. Supports different clustering algorithms through the model interface.

Parameters:

Name               | Type                        | Description                                         | Default
summaries          | List[ConversationSummary]   | List of conversation summaries to cluster           | required
model              | BaseClusterModel            | Model to use for clustering (HDBSCAN, KMeans, etc.) | required
checkpoint_manager | Optional[CheckpointManager] | Optional checkpoint manager for caching             | None

Returns:

Type          | Description
List[Cluster] | List of base clusters

Example

cluster_model = ClusterModel(algorithm="hdbscan")
clusters = await generate_base_clusters(
    summaries=conversation_summaries,
    model=cluster_model,
    checkpoint_manager=checkpoint_mgr
)

Source code in kura/v1/kura.py
async def generate_base_clusters_from_conversation_summaries(
    summaries: List[ConversationSummary],
    *,
    model: BaseClusterModel,
    checkpoint_manager: Optional[CheckpointManager] = None,
) -> List[Cluster]:
    """Generate base clusters from conversation summaries.

    This function groups similar summaries into initial clusters using
    the provided clustering model. Supports different clustering algorithms
    through the model interface.

    Args:
        summaries: List of conversation summaries to cluster
        model: Model to use for clustering (HDBSCAN, KMeans, etc.)
        checkpoint_manager: Optional checkpoint manager for caching

    Returns:
        List of base clusters

    Example:
        >>> cluster_model = ClusterModel(algorithm="hdbscan")
        >>> clusters = await generate_base_clusters(
        ...     summaries=conversation_summaries,
        ...     model=cluster_model,
        ...     checkpoint_manager=checkpoint_mgr
        ... )
    """
    logger.info(
        f"Starting clustering of {len(summaries)} summaries using {type(model).__name__}"
    )

    # Try to load from checkpoint
    if checkpoint_manager:
        cached = checkpoint_manager.load_checkpoint(model.checkpoint_filename, Cluster)
        if cached:
            logger.info(f"Loaded {len(cached)} clusters from checkpoint")
            return cached

    # Generate clusters
    logger.info("Generating new clusters...")
    clusters = await model.cluster_summaries(summaries)
    logger.info(f"Generated {len(clusters)} clusters")

    # Save to checkpoint
    if checkpoint_manager:
        checkpoint_manager.save_checkpoint(model.checkpoint_filename, clusters)

    return clusters

kura.reduce_clusters_from_base_clusters(clusters: List[Cluster], *, model: BaseMetaClusterModel, checkpoint_manager: Optional[CheckpointManager] = None) -> List[Cluster] async

Reduce clusters into a hierarchical structure.

Iteratively combines similar clusters until the number of root clusters is less than or equal to the model's max_clusters setting.

Parameters:

Name               | Type                        | Description                                | Default
clusters           | List[Cluster]               | List of initial clusters to reduce         | required
model              | BaseMetaClusterModel        | Meta-clustering model to use for reduction | required
checkpoint_manager | Optional[CheckpointManager] | Optional checkpoint manager for caching    | None

Returns:

Type          | Description
List[Cluster] | List of clusters with hierarchical structure

Example

meta_model = MetaClusterModel(max_clusters=5)
reduced = await reduce_clusters(
    clusters=base_clusters,
    model=meta_model,
    checkpoint_manager=checkpoint_mgr
)

Source code in kura/v1/kura.py
async def reduce_clusters_from_base_clusters(
    clusters: List[Cluster],
    *,
    model: BaseMetaClusterModel,
    checkpoint_manager: Optional[CheckpointManager] = None,
) -> List[Cluster]:
    """Reduce clusters into a hierarchical structure.

    Iteratively combines similar clusters until the number of root clusters
    is less than or equal to the model's max_clusters setting.

    Args:
        clusters: List of initial clusters to reduce
        model: Meta-clustering model to use for reduction
        checkpoint_manager: Optional checkpoint manager for caching

    Returns:
        List of clusters with hierarchical structure

    Example:
        >>> meta_model = MetaClusterModel(max_clusters=5)
        >>> reduced = await reduce_clusters(
        ...     clusters=base_clusters,
        ...     model=meta_model,
        ...     checkpoint_manager=checkpoint_mgr
        ... )
    """
    logger.info(
        f"Starting cluster reduction from {len(clusters)} initial clusters using {type(model).__name__}"
    )

    # Try to load from checkpoint
    if checkpoint_manager:
        cached = checkpoint_manager.load_checkpoint(model.checkpoint_filename, Cluster)
        if cached:
            root_count = len([c for c in cached if c.parent_id is None])
            logger.info(
                f"Loaded {len(cached)} clusters from checkpoint ({root_count} root clusters)"
            )
            return cached

    # Start with all clusters as potential roots
    all_clusters = clusters.copy()
    root_clusters = clusters.copy()

    # Get max_clusters from model if available, otherwise use default
    max_clusters = getattr(model, "max_clusters", 10)
    logger.info(f"Starting with {len(root_clusters)} clusters, target: {max_clusters}")

    # Iteratively reduce until we have desired number of root clusters
    while len(root_clusters) > max_clusters:
        # Get updated clusters from meta-clustering
        new_current_level = await model.reduce_clusters(root_clusters)

        # Find new root clusters (those without parents)
        root_clusters = [c for c in new_current_level if c.parent_id is None]

        # Remove old clusters that now have parents
        old_cluster_ids = {c.id for c in new_current_level if c.parent_id}
        all_clusters = [c for c in all_clusters if c.id not in old_cluster_ids]

        # Add new clusters to the complete list
        all_clusters.extend(new_current_level)

        logger.info(f"Reduced to {len(root_clusters)} root clusters")

    logger.info(
        f"Cluster reduction complete: {len(all_clusters)} total clusters, {len(root_clusters)} root clusters"
    )

    # Save to checkpoint
    if checkpoint_manager:
        checkpoint_manager.save_checkpoint(model.checkpoint_filename, all_clusters)

    return all_clusters
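
The returned list contains every cluster in the hierarchy, not only the roots. Below is a minimal sketch of walking the result, assuming only the Cluster fields already used in the source above (id and parent_id).

reduced = await reduce_clusters_from_base_clusters(
    clusters=base_clusters, model=meta_model, checkpoint_manager=checkpoint_mgr
)

# Roots have no parent; every other cluster points at its parent via parent_id.
roots = [c for c in reduced if c.parent_id is None]
children = {}
for c in reduced:
    if c.parent_id is not None:
        children.setdefault(c.parent_id, []).append(c)

for root in roots:
    print(root.id, "->", [child.id for child in children.get(root.id, [])])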

kura.reduce_dimensionality_from_clusters(clusters: List[Cluster], *, model: BaseDimensionalityReduction, checkpoint_manager: Optional[CheckpointManager] = None) -> List[ProjectedCluster] async

Reduce dimensions of clusters for visualization.

Projects clusters to 2D space using the provided dimensionality reduction model. Supports different algorithms (UMAP, t-SNE, PCA, etc.) through the model interface.

Parameters:

Name               | Type                         | Description                                               | Default
clusters           | List[Cluster]                | List of clusters to project                               | required
model              | BaseDimensionalityReduction  | Dimensionality reduction model to use (UMAP, t-SNE, etc.) | required
checkpoint_manager | Optional[CheckpointManager]  | Optional checkpoint manager for caching                   | None

Returns:

Type                   | Description
List[ProjectedCluster] | List of projected clusters with 2D coordinates

Example

dim_model = HDBUMAP(n_components=2)
projected = await reduce_dimensionality(
    clusters=hierarchical_clusters,
    model=dim_model,
    checkpoint_manager=checkpoint_mgr
)

Source code in kura/v1/kura.py
async def reduce_dimensionality_from_clusters(
    clusters: List[Cluster],
    *,
    model: BaseDimensionalityReduction,
    checkpoint_manager: Optional[CheckpointManager] = None,
) -> List[ProjectedCluster]:
    """Reduce dimensions of clusters for visualization.

    Projects clusters to 2D space using the provided dimensionality reduction model.
    Supports different algorithms (UMAP, t-SNE, PCA, etc.) through the model interface.

    Args:
        clusters: List of clusters to project
        model: Dimensionality reduction model to use (UMAP, t-SNE, etc.)
        checkpoint_manager: Optional checkpoint manager for caching

    Returns:
        List of projected clusters with 2D coordinates

    Example:
        >>> dim_model = HDBUMAP(n_components=2)
        >>> projected = await reduce_dimensionality(
        ...     clusters=hierarchical_clusters,
        ...     model=dim_model,
        ...     checkpoint_manager=checkpoint_mgr
        ... )
    """
    logger.info(
        f"Starting dimensionality reduction for {len(clusters)} clusters using {type(model).__name__}"
    )

    # Try to load from checkpoint
    if checkpoint_manager:
        cached = checkpoint_manager.load_checkpoint(
            model.checkpoint_filename, ProjectedCluster
        )
        if cached:
            logger.info(f"Loaded {len(cached)} projected clusters from checkpoint")
            return cached

    # Reduce dimensionality
    logger.info("Projecting clusters to 2D space...")
    projected_clusters = await model.reduce_dimensionality(clusters)
    logger.info(f"Projected {len(projected_clusters)} clusters to 2D")

    # Save to checkpoint
    if checkpoint_manager:
        checkpoint_manager.save_checkpoint(
            model.checkpoint_filename, projected_clusters
        )

    return projected_clusters

Checkpoint Management

kura.CheckpointManager

Handles checkpoint loading and saving for pipeline steps.
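
A short usage sketch. Checkpoints are plain JSONL files (one serialized Pydantic model per line), so they can also be read and written outside the pipeline functions; the model class passed to load_checkpoint must match what was saved. ConversationSummary's import path is not shown on this page, so the load call below is illustrative.

from kura import CheckpointManager

checkpoint_mgr = CheckpointManager("./checkpoints")  # creates the directory if missing
disabled_mgr = CheckpointManager("./checkpoints", enabled=False)  # no-op caching

# Manual load/save, mirroring what the pipeline functions do internally.
# ConversationSummary import omitted; its location is not documented on this page.
cached = checkpoint_mgr.load_checkpoint("summaries.jsonl", ConversationSummary)
if cached is not None:
    print(f"Loaded {len(cached)} cached summaries")
# Saving works the same way once summaries have been produced:
# checkpoint_mgr.save_checkpoint("summaries.jsonl", summaries)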

Source code in kura/v1/kura.py
class CheckpointManager:
    """Handles checkpoint loading and saving for pipeline steps."""

    def __init__(self, checkpoint_dir: str, *, enabled: bool = True):
        """Initialize checkpoint manager.

        Args:
            checkpoint_dir: Directory for saving checkpoints
            enabled: Whether checkpointing is enabled
        """
        self.checkpoint_dir = checkpoint_dir
        self.enabled = enabled

        if self.enabled:
            self.setup_checkpoint_dir()

    def setup_checkpoint_dir(self) -> None:
        """Create checkpoint directory if it doesn't exist."""
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
            logger.info(f"Created checkpoint directory: {self.checkpoint_dir}")

    def get_checkpoint_path(self, filename: str) -> str:
        """Get full path for a checkpoint file."""
        return os.path.join(self.checkpoint_dir, filename)

    def load_checkpoint(self, filename: str, model_class: type[T]) -> Optional[List[T]]:
        """Load data from a checkpoint file if it exists.

        Args:
            filename: Name of the checkpoint file
            model_class: Pydantic model class for deserializing the data

        Returns:
            List of model instances if checkpoint exists, None otherwise
        """
        if not self.enabled:
            return None

        checkpoint_path = self.get_checkpoint_path(filename)
        if os.path.exists(checkpoint_path):
            logger.info(
                f"Loading checkpoint from {checkpoint_path} for {model_class.__name__}"
            )
            with open(checkpoint_path, "r") as f:
                return [model_class.model_validate_json(line) for line in f]
        return None

    def save_checkpoint(self, filename: str, data: List[T]) -> None:
        """Save data to a checkpoint file.

        Args:
            filename: Name of the checkpoint file
            data: List of model instances to save
        """
        if not self.enabled:
            return

        checkpoint_path = self.get_checkpoint_path(filename)
        with open(checkpoint_path, "w") as f:
            for item in data:
                f.write(item.model_dump_json() + "\n")
        logger.info(f"Saved checkpoint to {checkpoint_path} with {len(data)} items")

checkpoint_dir = checkpoint_dir instance-attribute

enabled = enabled instance-attribute

__init__(checkpoint_dir: str, *, enabled: bool = True)

Initialize checkpoint manager.

Parameters:

Name           | Type | Description                      | Default
checkpoint_dir | str  | Directory for saving checkpoints | required
enabled        | bool | Whether checkpointing is enabled | True

Source code in kura/v1/kura.py
def __init__(self, checkpoint_dir: str, *, enabled: bool = True):
    """Initialize checkpoint manager.

    Args:
        checkpoint_dir: Directory for saving checkpoints
        enabled: Whether checkpointing is enabled
    """
    self.checkpoint_dir = checkpoint_dir
    self.enabled = enabled

    if self.enabled:
        self.setup_checkpoint_dir()

get_checkpoint_path(filename: str) -> str

Get full path for a checkpoint file.

Source code in kura/v1/kura.py
def get_checkpoint_path(self, filename: str) -> str:
    """Get full path for a checkpoint file."""
    return os.path.join(self.checkpoint_dir, filename)

load_checkpoint(filename: str, model_class: type[T]) -> Optional[List[T]]

Load data from a checkpoint file if it exists.

Parameters:

Name        | Type    | Description                                     | Default
filename    | str     | Name of the checkpoint file                     | required
model_class | type[T] | Pydantic model class for deserializing the data | required

Returns:

Type              | Description
Optional[List[T]] | List of model instances if checkpoint exists, None otherwise

Source code in kura/v1/kura.py
def load_checkpoint(self, filename: str, model_class: type[T]) -> Optional[List[T]]:
    """Load data from a checkpoint file if it exists.

    Args:
        filename: Name of the checkpoint file
        model_class: Pydantic model class for deserializing the data

    Returns:
        List of model instances if checkpoint exists, None otherwise
    """
    if not self.enabled:
        return None

    checkpoint_path = self.get_checkpoint_path(filename)
    if os.path.exists(checkpoint_path):
        logger.info(
            f"Loading checkpoint from {checkpoint_path} for {model_class.__name__}"
        )
        with open(checkpoint_path, "r") as f:
            return [model_class.model_validate_json(line) for line in f]
    return None

save_checkpoint(filename: str, data: List[T]) -> None

Save data to a checkpoint file.

Parameters:

Name     | Type    | Description                     | Default
filename | str     | Name of the checkpoint file     | required
data     | List[T] | List of model instances to save | required

Source code in kura/v1/kura.py
def save_checkpoint(self, filename: str, data: List[T]) -> None:
    """Save data to a checkpoint file.

    Args:
        filename: Name of the checkpoint file
        data: List of model instances to save
    """
    if not self.enabled:
        return

    checkpoint_path = self.get_checkpoint_path(filename)
    with open(checkpoint_path, "w") as f:
        for item in data:
            f.write(item.model_dump_json() + "\n")
    logger.info(f"Saved checkpoint to {checkpoint_path} with {len(data)} items")

setup_checkpoint_dir() -> None

Create checkpoint directory if it doesn't exist.

Source code in kura/v1/kura.py
def setup_checkpoint_dir(self) -> None:
    """Create checkpoint directory if it doesn't exist."""
    if not os.path.exists(self.checkpoint_dir):
        os.makedirs(self.checkpoint_dir)
        logger.info(f"Created checkpoint directory: {self.checkpoint_dir}")

Implementation Classes

Embedding Models

kura.embedding

logger = logging.getLogger(__name__) module-attribute

OpenAIEmbeddingModel

Bases: BaseEmbeddingModel
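
A minimal usage sketch for the OpenAI-backed embedder. It assumes the OPENAI_API_KEY environment variable is set, since the underlying AsyncOpenAI client is constructed without arguments.

import asyncio

from kura.embedding import OpenAIEmbeddingModel


async def main():
    model = OpenAIEmbeddingModel(
        model_name="text-embedding-3-small",
        model_batch_size=50,    # texts sent per API request
        n_concurrent_jobs=5,    # batches embedded concurrently
    )
    vectors = await model.embed(["first document", "second document"])
    print(len(vectors), len(vectors[0]))  # 2 embeddings, 1536 dimensions for this model


asyncio.run(main())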

Source code in kura/embedding.py
class OpenAIEmbeddingModel(BaseEmbeddingModel):
    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        model_batch_size: int = 50,
        n_concurrent_jobs: int = 5,
    ):
        self.client = AsyncOpenAI()
        self.model_name = model_name
        self._model_batch_size = model_batch_size
        self._n_concurrent_jobs = n_concurrent_jobs
        self._semaphore = Semaphore(n_concurrent_jobs)
        logger.info(
            f"Initialized OpenAIEmbeddingModel with model={model_name}, batch_size={model_batch_size}, concurrent_jobs={n_concurrent_jobs}"
        )

    def slug(self):
        return f"openai:{self.model_name}-batchsize:{self._model_batch_size}-concurrent:{self._n_concurrent_jobs}"

    @retry(wait=wait_fixed(3), stop=stop_after_attempt(3))
    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed a single batch of texts."""
        async with self._semaphore:
            try:
                logger.debug(
                    f"Embedding batch of {len(texts)} texts using model {self.model_name}"
                )
                resp = await self.client.embeddings.create(
                    input=texts, model=self.model_name
                )
                embeddings = [item.embedding for item in resp.data]
                logger.debug(
                    f"Successfully embedded batch of {len(texts)} texts, got {len(embeddings)} embeddings"
                )
                return embeddings
            except Exception as e:
                logger.error(f"Failed to embed batch of {len(texts)} texts: {e}")
                raise

    async def embed(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            logger.debug("Empty text list provided, returning empty embeddings")
            return []

        logger.info(f"Starting embedding of {len(texts)} texts using {self.model_name}")

        # Create batches
        batches = _batch_texts(texts, self._model_batch_size)
        logger.debug(
            f"Split {len(texts)} texts into {len(batches)} batches of size {self._model_batch_size}"
        )

        # Process all batches concurrently
        tasks = [self._embed_batch(batch) for batch in batches]
        try:
            results_list_of_lists = await gather(*tasks)
            logger.debug(f"Completed embedding {len(batches)} batches")
        except Exception as e:
            logger.error(f"Failed to embed texts: {e}")
            raise

        # Flatten results
        embeddings = []
        for result_batch in results_list_of_lists:
            embeddings.extend(result_batch)

        logger.info(
            f"Successfully embedded {len(texts)} texts, produced {len(embeddings)} embeddings"
        )
        return embeddings

client = AsyncOpenAI() instance-attribute

model_name = model_name instance-attribute

__init__(model_name: str = 'text-embedding-3-small', model_batch_size: int = 50, n_concurrent_jobs: int = 5)

Source code in kura/embedding.py
def __init__(
    self,
    model_name: str = "text-embedding-3-small",
    model_batch_size: int = 50,
    n_concurrent_jobs: int = 5,
):
    self.client = AsyncOpenAI()
    self.model_name = model_name
    self._model_batch_size = model_batch_size
    self._n_concurrent_jobs = n_concurrent_jobs
    self._semaphore = Semaphore(n_concurrent_jobs)
    logger.info(
        f"Initialized OpenAIEmbeddingModel with model={model_name}, batch_size={model_batch_size}, concurrent_jobs={n_concurrent_jobs}"
    )

embed(texts: list[str]) -> list[list[float]] async

Source code in kura/embedding.py
async def embed(self, texts: list[str]) -> list[list[float]]:
    if not texts:
        logger.debug("Empty text list provided, returning empty embeddings")
        return []

    logger.info(f"Starting embedding of {len(texts)} texts using {self.model_name}")

    # Create batches
    batches = _batch_texts(texts, self._model_batch_size)
    logger.debug(
        f"Split {len(texts)} texts into {len(batches)} batches of size {self._model_batch_size}"
    )

    # Process all batches concurrently
    tasks = [self._embed_batch(batch) for batch in batches]
    try:
        results_list_of_lists = await gather(*tasks)
        logger.debug(f"Completed embedding {len(batches)} batches")
    except Exception as e:
        logger.error(f"Failed to embed texts: {e}")
        raise

    # Flatten results
    embeddings = []
    for result_batch in results_list_of_lists:
        embeddings.extend(result_batch)

    logger.info(
        f"Successfully embedded {len(texts)} texts, produced {len(embeddings)} embeddings"
    )
    return embeddings

slug()

Source code in kura/embedding.py
def slug(self):
    return f"openai:{self.model_name}-batchsize:{self._model_batch_size}-concurrent:{self._n_concurrent_jobs}"

SentenceTransformerEmbeddingModel

Bases: BaseEmbeddingModel
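
A local, API-free counterpart; it requires the optional sentence-transformers dependency to be installed.

import asyncio

from kura.embedding import SentenceTransformerEmbeddingModel


async def main():
    model = SentenceTransformerEmbeddingModel(model_name="all-MiniLM-L6-v2")
    vectors = await model.embed(["hello world"])
    print(len(vectors[0]))  # 384 dimensions for all-MiniLM-L6-v2


asyncio.run(main())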

Source code in kura/embedding.py
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        model_batch_size: int = 128,
    ):
        from sentence_transformers import SentenceTransformer  # type: ignore

        logger.info(
            f"Initializing SentenceTransformerEmbeddingModel with model={model_name}, batch_size={model_batch_size}"
        )
        try:
            self.model = SentenceTransformer(model_name)
            self._model_batch_size = model_batch_size
            logger.info(f"Successfully loaded SentenceTransformer model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
            raise

    @retry(wait=wait_fixed(3), stop=stop_after_attempt(3))
    async def embed(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            logger.debug("Empty text list provided, returning empty embeddings")
            return []

        logger.info(
            f"Starting embedding of {len(texts)} texts using SentenceTransformer"
        )

        # Create batches
        batches = _batch_texts(texts, self._model_batch_size)
        logger.debug(
            f"Split {len(texts)} texts into {len(batches)} batches of size {self._model_batch_size}"
        )

        # Process all batches
        embeddings = []
        try:
            for i, batch in enumerate(batches):
                logger.debug(
                    f"Processing batch {i + 1}/{len(batches)} with {len(batch)} texts"
                )
                batch_embeddings = self.model.encode(batch).tolist()
                embeddings.extend(batch_embeddings)
                logger.debug(f"Completed batch {i + 1}/{len(batches)}")

            logger.info(
                f"Successfully embedded {len(texts)} texts using SentenceTransformer, produced {len(embeddings)} embeddings"
            )
        except Exception as e:
            logger.error(f"Failed to embed texts using SentenceTransformer: {e}")
            raise

        return embeddings

model = SentenceTransformer(model_name) instance-attribute

__init__(model_name: str = 'all-MiniLM-L6-v2', model_batch_size: int = 128)

Source code in kura/embedding.py
def __init__(
    self,
    model_name: str = "all-MiniLM-L6-v2",
    model_batch_size: int = 128,
):
    from sentence_transformers import SentenceTransformer  # type: ignore

    logger.info(
        f"Initializing SentenceTransformerEmbeddingModel with model={model_name}, batch_size={model_batch_size}"
    )
    try:
        self.model = SentenceTransformer(model_name)
        self._model_batch_size = model_batch_size
        logger.info(f"Successfully loaded SentenceTransformer model: {model_name}")
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
        raise

embed(texts: list[str]) -> list[list[float]] async

Source code in kura/embedding.py
@retry(wait=wait_fixed(3), stop=stop_after_attempt(3))
async def embed(self, texts: list[str]) -> list[list[float]]:
    if not texts:
        logger.debug("Empty text list provided, returning empty embeddings")
        return []

    logger.info(
        f"Starting embedding of {len(texts)} texts using SentenceTransformer"
    )

    # Create batches
    batches = _batch_texts(texts, self._model_batch_size)
    logger.debug(
        f"Split {len(texts)} texts into {len(batches)} batches of size {self._model_batch_size}"
    )

    # Process all batches
    embeddings = []
    try:
        for i, batch in enumerate(batches):
            logger.debug(
                f"Processing batch {i + 1}/{len(batches)} with {len(batch)} texts"
            )
            batch_embeddings = self.model.encode(batch).tolist()
            embeddings.extend(batch_embeddings)
            logger.debug(f"Completed batch {i + 1}/{len(batches)}")

        logger.info(
            f"Successfully embedded {len(texts)} texts using SentenceTransformer, produced {len(embeddings)} embeddings"
        )
    except Exception as e:
        logger.error(f"Failed to embed texts using SentenceTransformer: {e}")
        raise

    return embeddings

Summarization

kura.summarisation

logger = logging.getLogger(__name__) module-attribute

SummaryModel

Bases: BaseSummaryModel
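
A minimal usage sketch based on the constructor and summarise() shown below. The model string uses instructor's provider/model format, and the provider's API key is expected in the environment for the default model.

import asyncio

from kura.summarisation import SummaryModel


async def main(conversations):
    model = SummaryModel(
        model="openai/gpt-4o-mini",
        max_concurrent_requests=50,
    )
    summaries = await model.summarise(conversations)
    for s in summaries[:3]:
        print(s.chat_id, s.summary)


# asyncio.run(main(my_conversations))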

Source code in kura/summarisation.py
class SummaryModel(BaseSummaryModel):
    @property
    def checkpoint_filename(self) -> str:
        """The filename to use for checkpointing this model's output."""
        return "summaries.jsonl"

    def __init__(
        self,
        model: str = "openai/gpt-4o-mini",
        max_concurrent_requests: int = 50,
        extractors: list[
            Callable[
                [Conversation, Semaphore],
                Union[ExtractedProperty, list[ExtractedProperty]],
            ]
        ] = [],
        console: Optional["Console"] = None,
        **kwargs,  # For future use
    ):
        self.sems = None
        self.extractors = extractors
        self.max_concurrent_requests = max_concurrent_requests
        self.model = model
        self.console = console
        logger.info(
            f"Initialized SummaryModel with model={model}, max_concurrent_requests={max_concurrent_requests}, extractors={len(extractors)}"
        )

    async def _gather_with_progress(
        self,
        tasks,
        desc: str = "Processing",
        disable: bool = False,
        show_preview: bool = False,
    ):
        """Helper method to run async gather with Rich progress bar if available, otherwise tqdm."""
        if self.console and not disable:
            try:
                from rich.progress import (
                    Progress,
                    SpinnerColumn,
                    TextColumn,
                    BarColumn,
                    TaskProgressColumn,
                    TimeRemainingColumn,
                )
                from rich.live import Live
                from rich.layout import Layout
                from rich.panel import Panel
                from rich.text import Text
                from rich.errors import LiveError

                if show_preview:
                    # Use Live display with progress and preview buffer
                    layout = Layout()
                    layout.split_column(
                        Layout(name="progress", size=3), Layout(name="preview")
                    )

                    preview_buffer = []
                    max_preview_items = 3

                    # Create progress with cleaner display
                    progress = Progress(
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        TimeRemainingColumn(),
                        console=self.console,
                    )
                    task_id = progress.add_task(f"[cyan]{desc}...", total=len(tasks))
                    layout["progress"].update(progress)

                    try:
                        with Live(layout, console=self.console, refresh_per_second=4):
                            completed_tasks = []
                            for i, task in enumerate(asyncio.as_completed(tasks)):
                                result = await task
                                completed_tasks.append(result)
                                progress.update(task_id, completed=i + 1)

                                # Add to preview buffer if it's a ConversationSummary
                                if hasattr(result, "summary") and hasattr(
                                    result, "chat_id"
                                ):
                                    preview_buffer.append(result)
                                    if len(preview_buffer) > max_preview_items:
                                        preview_buffer.pop(0)

                                    # Update preview display
                                    preview_text = Text()
                                    for j, summary in enumerate(preview_buffer):
                                        # Color based on user frustration level
                                        frustration_style = {
                                            1: "green",  # Not frustrated
                                            2: "yellow",  # Slightly frustrated
                                            3: "orange3",  # Moderately frustrated
                                            4: "red",  # Very frustrated
                                            5: "red1",  # Extremely frustrated
                                        }.get(summary.user_frustration, "white")

                                        # Color based on concerning score
                                        concern_style = {
                                            1: "green",  # Not concerning
                                            2: "yellow",  # Slightly concerning
                                            3: "orange3",  # Moderately concerning
                                            4: "red",  # Very concerning
                                            5: "red1",  # Extremely concerning
                                        }.get(summary.concerning_score, "white")

                                        preview_text.append(
                                            f"Chat {summary.chat_id[:8]}...: ",
                                            style="bold blue",
                                        )
                                        preview_text.append(
                                            f"{summary.summary[:100]}...\n",
                                            style=frustration_style,
                                        )

                                        if summary.request:
                                            preview_text.append(
                                                f"Request: {summary.request[:50]}...\n",
                                                style=frustration_style,
                                            )
                                        if summary.languages:
                                            preview_text.append(
                                                f"Languages: {', '.join(summary.languages)}\n",
                                                style="dim cyan",
                                            )
                                        if summary.task:
                                            preview_text.append(
                                                f"Task: {summary.task[:50]}...\n",
                                                style=concern_style,
                                            )

                                        # Add frustration and concern indicators
                                        if summary.user_frustration:
                                            preview_text.append(
                                                f"Frustration: {'😊' * summary.user_frustration}\n",
                                                style=frustration_style,
                                            )
                                        if summary.concerning_score:
                                            preview_text.append(
                                                f"Concern: {'⚠️' * summary.concerning_score}\n",
                                                style=concern_style,
                                            )

                                        preview_text.append("\n")

                                    layout["preview"].update(
                                        Panel(
                                            preview_text,
                                            title=f"[green]Recent Summaries ({len(preview_buffer)}/{max_preview_items})",
                                            border_style="green",
                                        )
                                    )

                            return completed_tasks
                    except LiveError:
                        # If Rich Live fails, fall back to simple progress without Live
                        with progress:
                            completed_tasks = []
                            for i, task in enumerate(asyncio.as_completed(tasks)):
                                result = await task
                                completed_tasks.append(result)
                                progress.update(task_id, completed=i + 1)
                            return completed_tasks
                else:
                    # Regular progress bar without preview
                    progress = Progress(
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        TimeRemainingColumn(),
                        console=self.console,
                    )

                    with progress:
                        task_id = progress.add_task(
                            f"[cyan]{desc}...", total=len(tasks)
                        )

                        completed_tasks = []
                        for i, task in enumerate(asyncio.as_completed(tasks)):
                            result = await task
                            completed_tasks.append(result)
                            progress.update(task_id, completed=i + 1)

                        return completed_tasks

            except (ImportError, LiveError):  # type: ignore
                # Rich not available or Live error, fall back to simple print statements
                self.console.print(f"[cyan]Starting {desc}...[/cyan]")
                completed_tasks = []
                for i, task in enumerate(asyncio.as_completed(tasks)):
                    result = await task
                    completed_tasks.append(result)
                    if (i + 1) % max(1, len(tasks) // 10) == 0 or i == len(tasks) - 1:
                        self.console.print(
                            f"[cyan]{desc}: {i + 1}/{len(tasks)} completed[/cyan]"
                        )
                self.console.print(f"[green]✓ {desc} completed![/green]")
                return completed_tasks
        else:
            # Use tqdm as fallback when Rich is not available or disabled
            return await tqdm_asyncio.gather(*tasks, desc=desc, disable=disable)

    async def summarise(
        self, conversations: list[Conversation]
    ) -> list[ConversationSummary]:
        # Initialise the Semaphore on each run so that it's attached to the same event loop
        self.semaphore = asyncio.Semaphore(self.max_concurrent_requests)

        logger.info(
            f"Starting summarization of {len(conversations)} conversations using model {self.model}"
        )

        summaries = await self._gather_with_progress(
            [
                self.summarise_conversation(conversation)
                for conversation in conversations
            ],
            desc=f"Summarising {len(conversations)} conversations",
            show_preview=True,
        )

        logger.info(
            f"Completed summarization of {len(conversations)} conversations, produced {len(summaries)} summaries"
        )
        return summaries

    async def apply_hooks(
        self, conversation: Conversation
    ) -> dict[str, Union[str, int, float, bool, list[str], list[int], list[float]]]:
        logger.debug(
            f"Applying {len(self.extractors)} extractors to conversation {conversation.chat_id}"
        )

        coros = [
            extractor(conversation, self.semaphore) for extractor in self.extractors
        ]

        try:
            metadata_extracted = await gather(*coros)  # pyright: ignore
            logger.debug(
                f"Successfully extracted metadata from {len(self.extractors)} extractors for conversation {conversation.chat_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to extract metadata for conversation {conversation.chat_id}: {e}"
            )
            raise

        metadata = {}
        for result in metadata_extracted:
            if isinstance(result, ExtractedProperty):
                if result.name in metadata:
                    logger.error(
                        f"Duplicate metadata name: {result.name} for conversation {conversation.chat_id}"
                    )
                    raise ValueError(
                        f"Duplicate metadata name: {result.name}. Please use unique names for each metadata property."
                    )

                metadata[result.name] = result.value

            if isinstance(result, list):
                for extracted_property in result:
                    assert isinstance(extracted_property, ExtractedProperty)
                    if extracted_property.name in metadata:
                        logger.error(
                            f"Duplicate metadata name: {extracted_property.name} for conversation {conversation.chat_id}"
                        )
                        raise ValueError(
                            f"Duplicate metadata name: {extracted_property.name}. Please use unique names for each metadata property."
                        )
                    metadata[extracted_property.name] = extracted_property.value

        logger.debug(
            f"Extracted {len(metadata)} metadata properties for conversation {conversation.chat_id}"
        )
        return metadata

    async def summarise_conversation(
        self, conversation: Conversation
    ) -> ConversationSummary:
        """
        This summarisation model is designed to extract key information from a conversation between an AI assistant and a user.
        It is designed to be used in a pipeline to summarise conversations and extract metadata.

        It is based on the Clio paper:
            https://assets.anthropic.com/m/7e1ab885d1b24176/original/Clio-Privacy-Preserving-Insights-into-Real-World-AI-Use.pdf

        It is designed to be used in a pipeline to summarise conversations and extract metadata.
        """
        logger.debug(
            f"Starting summarization of conversation {conversation.chat_id} with {len(conversation.messages)} messages"
        )

        client = instructor.from_provider(self.model, async_client=True)
        async with self.semaphore:  # type: ignore
            try:
                resp = await client.chat.completions.create(  # type: ignore
                    temperature=0.2,  # as per the Clio paper
                    messages=[
                        {
                            "role": "user",
                            "content": """
The following is a conversation between an AI assistant and a user:

<messages>
{% for message in messages %}
<message>{{message.role}}: {{message.content}}</message>
{% endfor %}
</messages>

Your job is to extract key information from this conversation. Be descriptive and assume neither good nor bad faith. Do not hesitate to handle socially harmful or sensitive topics; specificity around potentially harmful conversations is necessary for effective monitoring.

When extracting information, do not include any personally identifiable information (PII), like names, locations, phone numbers, email addresses, and so on. Do not include any proper nouns.

Extract the following information:

1. **Summary**: A clear and concise summary in at most two sentences. Don't say "Based on the conversation..." and avoid mentioning the AI assistant/chatbot directly.

2. **Request**: What is the user's overall request for the assistant? Start with "The user's overall request for the assistant is to"

3. **Languages**: What are the main languages of the conversation? Include both human languages (like 'english' or 'spanish') and programming languages (like 'python' or 'javascript'). Use lowercase full names.

4. **Task**: What task is the model being asked to perform in this conversation? Start with "The task is to"

5. **Concerning Score**: On a scale of 1 to 5, how concerning or potentially harmful is this conversation from a safety perspective?
   - 1: Not concerning at all. Completely benign conversation.
   - 2: Slightly concerning, but likely harmless.
   - 3: Moderately concerning. May warrant a closer look.
   - 4: Very concerning. Likely needs review.
   - 5: Extremely concerning. Immediate review needed.

6. **User Frustration**: On a scale of 1 to 5, how frustrated is the user with the assistant?
   - 1: Not frustrated at all. The user is happy with the assistant.
   - 2: Slightly frustrated. The user is slightly annoyed with the assistant.
   - 3: Moderately frustrated. The user is moderately annoyed with the assistant.
   - 4: Very frustrated. The user is very annoyed with the assistant.
   - 5: Extremely frustrated. The user is extremely annoyed with the assistant.

7. **Assistant Errors**: What errors did the assistant make?
   Example:
    - "Responses were too long and verbose"
    - "Misunderstood the user's intent or request"
    - "Used wrong tool for the task"
    - "Ignored user's stated preferences or constraints"
    - "Provided outdated or incorrect information"
    - "Failed to maintain conversation context"


Remember that
- Summaries should be concise and short. They should each be at most 1-2 sentences and at most 30 words.
- Summaries should start with "The user's overall request for the assistant is to"
- Make sure to omit any personally identifiable information (PII), like names, locations, phone numbers, email addressess, company names and so on.
- Make sure to indicate specific details such as programming languages, frameworks, libraries and so on which are relevant to the task.
                        """,
                        },
                    ],
                    context={"messages": conversation.messages},
                    response_model=GeneratedSummary,
                )
                logger.debug(
                    f"Successfully generated summary for conversation {conversation.chat_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to generate summary for conversation {conversation.chat_id}: {e}"
                )
                raise

        try:
            metadata = await self.apply_hooks(conversation)
            logger.debug(
                f"Successfully applied hooks for conversation {conversation.chat_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to apply hooks for conversation {conversation.chat_id}: {e}"
            )
            raise

        summary = ConversationSummary(
            chat_id=conversation.chat_id,
            **resp.model_dump(),
            metadata={
                "conversation_turns": len(conversation.messages),
                **conversation.metadata,
                **metadata,
            },
        )

        logger.debug(
            f"Completed summarization of conversation {conversation.chat_id} - concerning_score: {resp.concerning_score}, user_frustration: {resp.user_frustration}"
        )
        return summary

checkpoint_filename: str property

The filename to use for checkpointing this model's output.

console = console instance-attribute

extractors = extractors instance-attribute

max_concurrent_requests = max_concurrent_requests instance-attribute

model = model instance-attribute

sems = None instance-attribute

__init__(model: str = 'openai/gpt-4o-mini', max_concurrent_requests: int = 50, extractors: list[Callable[[Conversation, Semaphore], Union[ExtractedProperty, list[ExtractedProperty]]]] = [], console: Optional[Console] = None, **kwargs)
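
Each extractor is an async callable that receives the conversation and the shared semaphore and returns an ExtractedProperty (or a list of them); the extracted values are merged into ConversationSummary.metadata by apply_hooks below. The following is a hedged sketch: ExtractedProperty's import path and its name/value constructor arguments are inferred from how apply_hooks reads those fields and may differ in your installation.

from asyncio import Semaphore

# Import path assumed; only the name/value fields are visible on this page.
from kura.types import ExtractedProperty

from kura.summarisation import SummaryModel


async def turn_count(conversation, semaphore: Semaphore) -> ExtractedProperty:
    # Purely local computation, so the semaphore is not needed here;
    # an extractor that calls an external API should acquire it first.
    return ExtractedProperty(name="turn_count", value=len(conversation.messages))


model = SummaryModel(extractors=[turn_count])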

Source code in kura/summarisation.py
def __init__(
    self,
    model: str = "openai/gpt-4o-mini",
    max_concurrent_requests: int = 50,
    extractors: list[
        Callable[
            [Conversation, Semaphore],
            Union[ExtractedProperty, list[ExtractedProperty]],
        ]
    ] = [],
    console: Optional["Console"] = None,
    **kwargs,  # For future use
):
    self.sems = None
    self.extractors = extractors
    self.max_concurrent_requests = max_concurrent_requests
    self.model = model
    self.console = console
    logger.info(
        f"Initialized SummaryModel with model={model}, max_concurrent_requests={max_concurrent_requests}, extractors={len(extractors)}"
    )

apply_hooks(conversation: Conversation) -> dict[str, Union[str, int, float, bool, list[str], list[int], list[float]]] async

Source code in kura/summarisation.py
async def apply_hooks(
    self, conversation: Conversation
) -> dict[str, Union[str, int, float, bool, list[str], list[int], list[float]]]:
    logger.debug(
        f"Applying {len(self.extractors)} extractors to conversation {conversation.chat_id}"
    )

    coros = [
        extractor(conversation, self.semaphore) for extractor in self.extractors
    ]

    try:
        metadata_extracted = await gather(*coros)  # pyright: ignore
        logger.debug(
            f"Successfully extracted metadata from {len(self.extractors)} extractors for conversation {conversation.chat_id}"
        )
    except Exception as e:
        logger.error(
            f"Failed to extract metadata for conversation {conversation.chat_id}: {e}"
        )
        raise

    metadata = {}
    for result in metadata_extracted:
        if isinstance(result, ExtractedProperty):
            if result.name in metadata:
                logger.error(
                    f"Duplicate metadata name: {result.name} for conversation {conversation.chat_id}"
                )
                raise ValueError(
                    f"Duplicate metadata name: {result.name}. Please use unique names for each metadata property."
                )

            metadata[result.name] = result.value

        if isinstance(result, list):
            for extracted_property in result:
                assert isinstance(extracted_property, ExtractedProperty)
                if extracted_property.name in metadata:
                    logger.error(
                        f"Duplicate metadata name: {extracted_property.name} for conversation {conversation.chat_id}"
                    )
                    raise ValueError(
                        f"Duplicate metadata name: {extracted_property.name}. Please use unique names for each metadata property."
                    )
                metadata[extracted_property.name] = extracted_property.value

    logger.debug(
        f"Extracted {len(metadata)} metadata properties for conversation {conversation.chat_id}"
    )
    return metadata

summarise(conversations: list[Conversation]) -> list[ConversationSummary] async

Source code in kura/summarisation.py
async def summarise(
    self, conversations: list[Conversation]
) -> list[ConversationSummary]:
    # Initialise the Semaphore on each run so that it's attached to the same event loop
    self.semaphore = asyncio.Semaphore(self.max_concurrent_requests)

    logger.info(
        f"Starting summarization of {len(conversations)} conversations using model {self.model}"
    )

    summaries = await self._gather_with_progress(
        [
            self.summarise_conversation(conversation)
            for conversation in conversations
        ],
        desc=f"Summarising {len(conversations)} conversations",
        show_preview=True,
    )

    logger.info(
        f"Completed summarization of {len(conversations)} conversations, produced {len(summaries)} summaries"
    )
    return summaries

summarise_conversation(conversation: Conversation) -> ConversationSummary async

This summarisation model is designed to extract key information from a conversation between an AI assistant and a user, and is intended to be used in a pipeline to summarise conversations and extract metadata.

It is based on the Clio paper:

https://assets.anthropic.com/m/7e1ab885d1b24176/original/Clio-Privacy-Preserving-Insights-into-Real-World-AI-Use.pdf

Source code in kura/summarisation.py
    async def summarise_conversation(
        self, conversation: Conversation
    ) -> ConversationSummary:
        """
        This summarisation model is designed to extract key information from a conversation between an AI assistant and a user.
        It is designed to be used in a pipeline to summarise conversations and extract metadata.

        It is based on the Clio paper:
            https://assets.anthropic.com/m/7e1ab885d1b24176/original/Clio-Privacy-Preserving-Insights-into-Real-World-AI-Use.pdf

        It is designed to be used in a pipeline to summarise conversations and extract metadata.
        """
        logger.debug(
            f"Starting summarization of conversation {conversation.chat_id} with {len(conversation.messages)} messages"
        )

        client = instructor.from_provider(self.model, async_client=True)
        async with self.semaphore:  # type: ignore
            try:
                resp = await client.chat.completions.create(  # type: ignore
                    temperature=0.2,  # as per the Clio paper
                    messages=[
                        {
                            "role": "user",
                            "content": """
The following is a conversation between an AI assistant and a user:

<messages>
{% for message in messages %}
<message>{{message.role}}: {{message.content}}</message>
{% endfor %}
</messages>

Your job is to extract key information from this conversation. Be descriptive and assume neither good nor bad faith. Do not hesitate to handle socially harmful or sensitive topics; specificity around potentially harmful conversations is necessary for effective monitoring.

When extracting information, do not include any personally identifiable information (PII), like names, locations, phone numbers, email addresses, and so on. Do not include any proper nouns.

Extract the following information:

1. **Summary**: A clear and concise summary in at most two sentences. Don't say "Based on the conversation..." and avoid mentioning the AI assistant/chatbot directly.

2. **Request**: What is the user's overall request for the assistant? Start with "The user's overall request for the assistant is to"

3. **Languages**: What are the main languages of the conversation? Include both human languages (like 'english' or 'spanish') and programming languages (like 'python' or 'javascript'). Use lowercase full names.

4. **Task**: What task is the model being asked to perform in this conversation? Start with "The task is to"

5. **Concerning Score**: On a scale of 1 to 5, how concerning or potentially harmful is this conversation from a safety perspective?
   - 1: Not concerning at all. Completely benign conversation.
   - 2: Slightly concerning, but likely harmless.
   - 3: Moderately concerning. May warrant a closer look.
   - 4: Very concerning. Likely needs review.
   - 5: Extremely concerning. Immediate review needed.

6. **User Frustration**: On a scale of 1 to 5, how frustrated is the user with the assistant?
   - 1: Not frustrated at all. The user is happy with the assistant.
   - 2: Slightly frustrated. The user is slightly annoyed with the assistant.
   - 3: Moderately frustrated. The user is moderately annoyed with the assistant.
   - 4: Very frustrated. The user is very annoyed with the assistant.
   - 5: Extremely frustrated. The user is extremely annoyed with the assistant.

7. **Assistant Errors**: What errors did the assistant make?
   Example:
    - "Responses were too long and verbose"
    - "Misunderstood the user's intent or request"
    - "Used wrong tool for the task"
    - "Ignored user's stated preferences or constraints"
    - "Provided outdated or incorrect information"
    - "Failed to maintain conversation context"


Remember that
- Summaries should be concise and short. They should each be at most 1-2 sentences and at most 30 words.
- Summaries should start with "The user's overall request for the assistant is to"
- Make sure to omit any personally identifiable information (PII), like names, locations, phone numbers, email addresses, company names and so on.
- Make sure to indicate specific details such as programming languages, frameworks, libraries and so on which are relevant to the task.
                        """,
                        },
                    ],
                    context={"messages": conversation.messages},
                    response_model=GeneratedSummary,
                )
                logger.debug(
                    f"Successfully generated summary for conversation {conversation.chat_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to generate summary for conversation {conversation.chat_id}: {e}"
                )
                raise

        try:
            metadata = await self.apply_hooks(conversation)
            logger.debug(
                f"Successfully applied hooks for conversation {conversation.chat_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to apply hooks for conversation {conversation.chat_id}: {e}"
            )
            raise

        summary = ConversationSummary(
            chat_id=conversation.chat_id,
            **resp.model_dump(),
            metadata={
                "conversation_turns": len(conversation.messages),
                **conversation.metadata,
                **metadata,
            },
        )

        logger.debug(
            f"Completed summarization of conversation {conversation.chat_id} - concerning_score: {resp.concerning_score}, user_frustration: {resp.user_frustration}"
        )
        return summary

Clustering

kura.cluster

logger = logging.getLogger(__name__) module-attribute

ClusterModel

Bases: BaseClusterModel

Source code in kura/cluster.py
class ClusterModel(BaseClusterModel):
    @property
    def checkpoint_filename(self) -> str:
        """The filename to use for checkpointing this model's output."""
        return "clusters.jsonl"

    def __init__(
        self,
        clustering_method: BaseClusteringMethod = KmeansClusteringMethod(),
        embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
        max_concurrent_requests: int = 50,
        model: str = "openai/gpt-4o-mini",
        console: Optional["Console"] = None,
        **kwargs,  # For future use
    ):
        self.clustering_method = clustering_method
        self.embedding_model = embedding_model
        self.max_concurrent_requests = max_concurrent_requests
        self.sem = Semaphore(max_concurrent_requests)
        self.client = instructor.from_provider(model, async_client=True)
        self.console = console
        logger.info(
            f"Initialized ClusterModel with clustering_method={type(clustering_method).__name__}, embedding_model={type(embedding_model).__name__}, max_concurrent_requests={max_concurrent_requests}, model={model}"
        )

    def get_contrastive_examples(
        self,
        cluster_id: int,
        cluster_id_to_summaries: dict[int, list[ConversationSummary]],
        limit: int = 10,
    ) -> list[ConversationSummary]:
        """Get contrastive examples from other clusters to help distinguish this cluster.

        Args:
            cluster_id (int): The id of the cluster to get contrastive examples for
            cluster_id_to_summaries (dict[int, list[ConversationSummary]]): A dictionary of cluster ids to their summaries
            limit (int, optional): The number of contrastive examples to return. Defaults to 10.

        Returns:
            list[ConversationSummary]: A list of contrastive examples from other clusters
        """
        other_clusters = [c for c in cluster_id_to_summaries.keys() if c != cluster_id]
        all_examples = []
        for cluster in other_clusters:
            all_examples.extend(cluster_id_to_summaries[cluster])

        logger.debug(
            f"Selecting contrastive examples for cluster {cluster_id}: found {len(all_examples)} examples from {len(other_clusters)} other clusters"
        )

        # If we don't have enough examples, return all of them
        if len(all_examples) <= limit:
            logger.debug(
                f"Using all {len(all_examples)} available contrastive examples (limit was {limit})"
            )
            return all_examples

        # Otherwise sample without replacement
        selected = list(np.random.choice(all_examples, size=limit, replace=False))
        logger.debug(
            f"Randomly selected {len(selected)} contrastive examples from {len(all_examples)} available"
        )
        return selected

    async def generate_cluster(
        self,
        summaries: list[ConversationSummary],
        contrastive_examples: list[ConversationSummary],
    ) -> Cluster:
        logger.debug(
            f"Generating cluster from {len(summaries)} summaries with {len(contrastive_examples)} contrastive examples"
        )

        async with self.sem:
            try:
                resp = await self.client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": """
You are tasked with summarizing a group of related statements into a short, precise, and accurate description and name. Your goal is to create a concise summary that captures the essence of these statements and distinguishes them from other similar groups of statements.

Summarize all the statements into a clear, precise, two-sentence description in the past tense. Your summary should be specific to this group and distinguish it from the contrastive answers of the other groups.

After creating the summary, generate a short name for the group of statements. This name should be at most ten words long (perhaps less) and be specific but also reflective of most of the statements (rather than reflecting only one or two).

The name should distinguish this group from the contrastive examples. For instance, "Write fantasy sexual roleplay with octopi and monsters", "Generate blog spam for gambling websites", or "Assist with high school math homework" would be better and more actionable than general terms like "Write erotic content" or "Help with homework". Be as descriptive as possible and assume neither good nor bad faith. Do not hesitate to identify and describe socially harmful or sensitive topics specifically; specificity is necessary for monitoring.

Present your output in the following format:
<summary> [Insert your two-sentence summary here] </summary>
<name> [Insert your generated short name here] </name>

The names you propose must follow these requirements:
- The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job.'
- Create names that are specific enough to be meaningful, but not so specific that they can't meaningfully represent many different statements.
- Avoid overly general or vague terms, and do not hesitate to describe socially harmful or sensitive topics (in fact, clusters that clearly describe harmful behavior are slightly preferred); specificity is necessary for observability and enforcement.
- Ensure that the cluster name is distinct from the contrastive examples.
- Use clear, concise, and descriptive language for the cluster name.

Below are the related statements:
<positive_examples>
{% for item in positive_examples %}{{ item }}
{% endfor %}
</positive_examples>

For context, here are statements from nearby groups that are NOT part of the group you're summarizing:
<contrastive_examples>
{% for item in contrastive_examples %}{{ item }}
{% endfor %}
</contrastive_examples>

Do not elaborate beyond what you say in the tags. Remember to analyze both the statements and the contrastive statements carefully to ensure your summary and name accurately represent the specific group while distinguishing it from others.
                            """,
                        },
                        {
                            "role": "user",
                            "content": "The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job.'",
                        },
                        {
                            "role": "assistant",
                            "content": "Sure, I will provide a clear, precise, and accurate summary and name for this cluster. I will be descriptive and assume neither good nor bad faith. Here is the summary, which I will follow with the name:",
                        },
                    ],
                    response_model=GeneratedCluster,
                    context={
                        "positive_examples": summaries,
                        "contrastive_examples": contrastive_examples,
                    },
                )

                cluster = Cluster(
                    name=resp.name,
                    description=resp.summary,
                    slug=resp.slug,
                    chat_ids=[item.chat_id for item in summaries],
                    parent_id=None,
                )

                logger.debug(
                    f"Successfully generated cluster '{resp.name}' with {len(summaries)} conversations"
                )
                return cluster

            except Exception as e:
                logger.error(
                    f"Failed to generate cluster from {len(summaries)} summaries: {e}"
                )
                raise

    async def _embed_summaries(
        self, summaries: list[ConversationSummary]
    ) -> list[list[float]]:
        """Embeds a list of conversation summaries."""
        if not summaries:
            logger.debug("Empty summaries list provided for embedding")
            return []

        logger.info(f"Starting embedding of {len(summaries)} conversation summaries")
        texts_to_embed = [str(item) for item in summaries]

        try:
            embeddings = await self.embedding_model.embed(texts_to_embed)
            logger.debug(
                f"Received {len(embeddings) if embeddings else 0} embeddings from embedding model"
            )
        except Exception as e:
            logger.error(f"Failed to embed {len(summaries)} summaries: {e}")
            raise

        if not embeddings or len(embeddings) != len(summaries):
            logger.error(
                f"Error: Number of embeddings ({len(embeddings) if embeddings else 0}) does not match number of summaries ({len(summaries)}) or embeddings are empty."
            )
            return []

        logger.info(f"Successfully embedded {len(summaries)} summaries")
        return embeddings

    async def _generate_clusters_from_embeddings(
        self, summaries: list[ConversationSummary], embeddings: list[list[float]]
    ) -> list[Cluster]:
        """Generates clusters from summaries and their embeddings."""
        logger.info(
            f"Generating clusters from {len(summaries)} summaries with embeddings"
        )

        # Set embeddings on the summary objects
        items_with_embeddings = [
            {"item": summary, "embedding": embedding}
            for summary, embedding in zip(summaries, embeddings)
        ]

        logger.debug("Set embeddings on summary objects, starting clustering")
        cluster_id_to_summaries = self.clustering_method.cluster(items_with_embeddings)
        logger.info(
            f"Clustering method produced {len(cluster_id_to_summaries)} clusters"
        )

        # Create tasks for cluster generation with contrastive examples
        tasks = []
        for cluster_id, conversation_summaries in cluster_id_to_summaries.items():
            logger.debug(
                f"Preparing cluster generation for cluster {cluster_id} with {len(conversation_summaries)} summaries"
            )

            # Get contrastive examples from other clusters to help distinguish this cluster
            contrastive_examples = self.get_contrastive_examples(
                cluster_id=cluster_id,
                cluster_id_to_summaries=cluster_id_to_summaries,
                limit=10,
            )

            # Create cluster generation task for this specific cluster
            task = self.generate_cluster(
                summaries=conversation_summaries,
                contrastive_examples=contrastive_examples,
            )
            tasks.append(task)

        logger.info(f"Starting concurrent generation of {len(tasks)} clusters")
        # Execute all cluster generation tasks concurrently with progress tracking
        clusters: list[Cluster] = await self._gather_with_progress(
            tasks=tasks,
            desc="Generating Base Clusters",
            show_preview=True,
        )
        logger.info(f"Successfully generated {len(clusters)} clusters")
        return clusters

    async def cluster_summaries(
        self, summaries: list[ConversationSummary]
    ) -> list[Cluster]:
        if not summaries:
            logger.warning("Empty summaries list provided to cluster_summaries")
            return []

        logger.info(
            f"Starting clustering process for {len(summaries)} conversation summaries"
        )

        embeddings = await self._embed_summaries(summaries)
        if not embeddings:
            logger.error(
                "Failed to generate embeddings, cannot proceed with clustering"
            )
            return []

        clusters = await self._generate_clusters_from_embeddings(summaries, embeddings)
        logger.info(
            f"Clustering process completed: generated {len(clusters)} clusters from {len(summaries)} summaries"
        )
        return clusters

    async def _gather_with_progress(
        self,
        tasks,
        desc: str = "Processing",
        disable: bool = False,
        show_preview: bool = False,
    ):
        """Helper method to run async gather with Rich progress bar if available, otherwise tqdm."""
        if self.console and not disable:
            try:
                from rich.progress import (
                    Progress,
                    SpinnerColumn,
                    TextColumn,
                    BarColumn,
                    TaskProgressColumn,
                    TimeRemainingColumn,
                )
                from rich.live import Live
                from rich.layout import Layout
                from rich.panel import Panel
                from rich.text import Text
                from rich.errors import LiveError
            except ImportError:
                return await tqdm_asyncio.gather(*tasks, desc=desc, disable=disable)

            if show_preview:
                # Use Live display with progress and cluster list
                layout = Layout()
                layout.split_column(
                    Layout(name="progress", size=3), Layout(name="clusters")
                )

                all_clusters = []

                # Create progress with cleaner display
                progress = Progress(
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    TimeRemainingColumn(),
                    console=self.console,
                )
                task_id = progress.add_task(f"[cyan]{desc}...", total=len(tasks))
                layout["progress"].update(progress)

                try:
                    with Live(layout, console=self.console, refresh_per_second=4):
                        completed_tasks = []
                        for i, task in enumerate(asyncio.as_completed(tasks)):
                            result = await task
                            completed_tasks.append(result)
                            progress.update(task_id, completed=i + 1)

                            # Add to cluster list if it's a Cluster
                            if hasattr(result, "name") and hasattr(
                                result, "description"
                            ):
                                all_clusters.append(result)

                                # Sort clusters by conversation count (largest first)
                                sorted_clusters = sorted(
                                    all_clusters,
                                    key=lambda x: len(x.chat_ids),
                                    reverse=True,
                                )

                                # Create formatted list display
                                cluster_text = Text()
                                for j, cluster in enumerate(sorted_clusters):
                                    cluster_text.append(f"#{j + 1} ", style="bold cyan")
                                    cluster_text.append(
                                        f"{cluster.name}\n", style="bold white"
                                    )
                                    cluster_text.append(
                                        f"    {cluster.description[:120]}...\n",
                                        style="dim white",
                                    )
                                    cluster_text.append(
                                        f"    💬 {len(cluster.chat_ids)} conversations\n\n",
                                        style="dim cyan",
                                    )

                                layout["clusters"].update(
                                    Panel(
                                        cluster_text,
                                        title=f"[green]Generated Clusters ({len(all_clusters)}) - Sorted by Size",
                                        border_style="green",
                                    )
                                )

                        return completed_tasks
                except LiveError:
                    # If Rich Live fails, run silently
                    return await asyncio.gather(*tasks)
            else:
                # Regular progress bar without preview
                try:
                    progress = Progress(
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        TimeRemainingColumn(),
                        console=self.console,
                    )

                    with progress:
                        task_id = progress.add_task(
                            f"[cyan]{desc}...", total=len(tasks)
                        )

                        completed_tasks = []
                        for i, task in enumerate(asyncio.as_completed(tasks)):
                            result = await task
                            completed_tasks.append(result)
                            progress.update(task_id, completed=i + 1)

                        return completed_tasks
                except (ImportError, LiveError):
                    # Rich not available or Live error, run silently
                    return await asyncio.gather(*tasks)
        else:
            # No console, run silently
            return await asyncio.gather(*tasks)

checkpoint_filename: str property

The filename to use for checkpointing this model's output.

client = instructor.from_provider(model, async_client=True) instance-attribute

clustering_method = clustering_method instance-attribute

console = console instance-attribute

embedding_model = embedding_model instance-attribute

max_concurrent_requests = max_concurrent_requests instance-attribute

sem = Semaphore(max_concurrent_requests) instance-attribute

__init__(clustering_method: BaseClusteringMethod = KmeansClusteringMethod(), embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(), max_concurrent_requests: int = 50, model: str = 'openai/gpt-4o-mini', console: Optional[Console] = None, **kwargs)
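A hedged construction example using the parameters above with explicit values; it assumes OpenAI credentials are available in the environment for both the instructor client and the embedding model, and the import paths follow the source shown in this reference.

from rich.console import Console

from kura.cluster import ClusterModel
from kura.embedding import OpenAIEmbeddingModel
from kura.k_means import KmeansClusteringMethod

cluster_model = ClusterModel(
    clustering_method=KmeansClusteringMethod(),  # default clustering backend
    embedding_model=OpenAIEmbeddingModel(),      # default embedding backend
    max_concurrent_requests=20,                  # throttle concurrent LLM calls
    model="openai/gpt-4o-mini",                  # provider/model string for instructor
    console=Console(),                           # optional Rich progress display
)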

Source code in kura/cluster.py
def __init__(
    self,
    clustering_method: BaseClusteringMethod = KmeansClusteringMethod(),
    embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
    max_concurrent_requests: int = 50,
    model: str = "openai/gpt-4o-mini",
    console: Optional["Console"] = None,
    **kwargs,  # For future use
):
    self.clustering_method = clustering_method
    self.embedding_model = embedding_model
    self.max_concurrent_requests = max_concurrent_requests
    self.sem = Semaphore(max_concurrent_requests)
    self.client = instructor.from_provider(model, async_client=True)
    self.console = console
    logger.info(
        f"Initialized ClusterModel with clustering_method={type(clustering_method).__name__}, embedding_model={type(embedding_model).__name__}, max_concurrent_requests={max_concurrent_requests}, model={model}"
    )

cluster_summaries(summaries: list[ConversationSummary]) -> list[Cluster] async
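Groups conversation summaries into base clusters by embedding them, running the configured clustering method, and generating a name and description for each cluster. A minimal sketch of calling it, assuming summaries is a list of ConversationSummary objects produced by the summarisation step and that a default ClusterModel can be constructed in your environment:

import asyncio
from kura.cluster import ClusterModel

async def build_clusters(summaries):
    model = ClusterModel()
    clusters = await model.cluster_summaries(summaries)
    for cluster in clusters:
        # Each Cluster carries a name, description and the chat ids it covers.
        print(cluster.name, len(cluster.chat_ids))
    return clusters

# clusters = asyncio.run(build_clusters(summaries))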

Source code in kura/cluster.py
async def cluster_summaries(
    self, summaries: list[ConversationSummary]
) -> list[Cluster]:
    if not summaries:
        logger.warning("Empty summaries list provided to cluster_summaries")
        return []

    logger.info(
        f"Starting clustering process for {len(summaries)} conversation summaries"
    )

    embeddings = await self._embed_summaries(summaries)
    if not embeddings:
        logger.error(
            "Failed to generate embeddings, cannot proceed with clustering"
        )
        return []

    clusters = await self._generate_clusters_from_embeddings(summaries, embeddings)
    logger.info(
        f"Clustering process completed: generated {len(clusters)} clusters from {len(summaries)} summaries"
    )
    return clusters

generate_cluster(summaries: list[ConversationSummary], contrastive_examples: list[ConversationSummary]) -> Cluster async

Source code in kura/cluster.py
    async def generate_cluster(
        self,
        summaries: list[ConversationSummary],
        contrastive_examples: list[ConversationSummary],
    ) -> Cluster:
        logger.debug(
            f"Generating cluster from {len(summaries)} summaries with {len(contrastive_examples)} contrastive examples"
        )

        async with self.sem:
            try:
                resp = await self.client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": """
You are tasked with summarizing a group of related statements into a short, precise, and accurate description and name. Your goal is to create a concise summary that captures the essence of these statements and distinguishes them from other similar groups of statements.

Summarize all the statements into a clear, precise, two-sentence description in the past tense. Your summary should be specific to this group and distinguish it from the contrastive answers of the other groups.

After creating the summary, generate a short name for the group of statements. This name should be at most ten words long (perhaps less) and be specific but also reflective of most of the statements (rather than reflecting only one or two).

The name should distinguish this group from the contrastive examples. For instance, "Write fantasy sexual roleplay with octopi and monsters", "Generate blog spam for gambling websites", or "Assist with high school math homework" would be better and more actionable than general terms like "Write erotic content" or "Help with homework". Be as descriptive as possible and assume neither good nor bad faith. Do not hesitate to identify and describe socially harmful or sensitive topics specifically; specificity is necessary for monitoring.

Present your output in the following format:
<summary> [Insert your two-sentence summary here] </summary>
<name> [Insert your generated short name here] </name>

The names you propose must follow these requirements:
- The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job.'
- Create names that are specific enough to be meaningful, but not so specific that they can't meaningfully represent many different statements.
- Avoid overly general or vague terms, and do not hesitate to describe socially harmful or sensitive topics (in fact, clusters that clearly describe harmful behavior are slightly preferred); specificity is necessary for observability and enforcement.
- Ensure that the cluster name is distinct from the contrastive examples.
- Use clear, concise, and descriptive language for the cluster name.

Below are the related statements:
<positive_examples>
{% for item in positive_examples %}{{ item }}
{% endfor %}
</positive_examples>

For context, here are statements from nearby groups that are NOT part of the group you're summarizing:
<contrastive_examples>
{% for item in contrastive_examples %}{{ item }}
{% endfor %}
</contrastive_examples>

Do not elaborate beyond what you say in the tags. Remember to analyze both the statements and the contrastive statements carefully to ensure your summary and name accurately represent the specific group while distinguishing it from others.
                            """,
                        },
                        {
                            "role": "user",
                            "content": "The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job.'",
                        },
                        {
                            "role": "assistant",
                            "content": "Sure, I will provide a clear, precise, and accurate summary and name for this cluster. I will be descriptive and assume neither good nor bad faith. Here is the summary, which I will follow with the name:",
                        },
                    ],
                    response_model=GeneratedCluster,
                    context={
                        "positive_examples": summaries,
                        "contrastive_examples": contrastive_examples,
                    },
                )

                cluster = Cluster(
                    name=resp.name,
                    description=resp.summary,
                    slug=resp.slug,
                    chat_ids=[item.chat_id for item in summaries],
                    parent_id=None,
                )

                logger.debug(
                    f"Successfully generated cluster '{resp.name}' with {len(summaries)} conversations"
                )
                return cluster

            except Exception as e:
                logger.error(
                    f"Failed to generate cluster from {len(summaries)} summaries: {e}"
                )
                raise

get_contrastive_examples(cluster_id: int, cluster_id_to_summaries: dict[int, list[ConversationSummary]], limit: int = 10) -> list[ConversationSummary]

Get contrastive examples from other clusters to help distinguish this cluster.

Parameters:

Name Type Description Default
cluster_id int

The id of the cluster to get contrastive examples for

required
cluster_id_to_summaries dict[int, list[ConversationSummary]]

A dictionary of cluster ids to their summaries

required
limit int

The number of contrastive examples to return. Defaults to 10.

10

Returns:

Type Description
list[ConversationSummary]

list[ConversationSummary]: A list of contrastive examples from other clusters
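An illustrative call, where cluster_id_to_summaries is the mapping produced by the clustering method; the variable names are placeholders, not part of the API.

contrastive = cluster_model.get_contrastive_examples(
    cluster_id=0,
    cluster_id_to_summaries={
        0: cluster_zero_summaries,  # summaries assigned to cluster 0
        1: cluster_one_summaries,   # summaries assigned to cluster 1
    },
    limit=10,
)
# Returns at most 10 summaries drawn only from clusters other than cluster 0.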

Source code in kura/cluster.py
def get_contrastive_examples(
    self,
    cluster_id: int,
    cluster_id_to_summaries: dict[int, list[ConversationSummary]],
    limit: int = 10,
) -> list[ConversationSummary]:
    """Get contrastive examples from other clusters to help distinguish this cluster.

    Args:
        cluster_id (int): The id of the cluster to get contrastive examples for
        cluster_id_to_summaries (dict[int, list[ConversationSummary]]): A dictionary of cluster ids to their summaries
        limit (int, optional): The number of contrastive examples to return. Defaults to 10.

    Returns:
        list[ConversationSummary]: A list of contrastive examples from other clusters
    """
    other_clusters = [c for c in cluster_id_to_summaries.keys() if c != cluster_id]
    all_examples = []
    for cluster in other_clusters:
        all_examples.extend(cluster_id_to_summaries[cluster])

    logger.debug(
        f"Selecting contrastive examples for cluster {cluster_id}: found {len(all_examples)} examples from {len(other_clusters)} other clusters"
    )

    # If we don't have enough examples, return all of them
    if len(all_examples) <= limit:
        logger.debug(
            f"Using all {len(all_examples)} available contrastive examples (limit was {limit})"
        )
        return all_examples

    # Otherwise sample without replacement
    selected = list(np.random.choice(all_examples, size=limit, replace=False))
    logger.debug(
        f"Randomly selected {len(selected)} contrastive examples from {len(all_examples)} available"
    )
    return selected

Meta-Clustering

kura.meta_cluster

logger = logging.getLogger(__name__) module-attribute

CandidateClusters

Bases: BaseModel

Source code in kura/meta_cluster.py
class CandidateClusters(BaseModel):
    candidate_cluster_names: list[str]

    @field_validator("candidate_cluster_names")
    def validate_candidate_cluster_names(cls, v: list[str]) -> list[str]:
        if len(v) == 0:
            raise ValueError("Candidate cluster names must be a non-empty list")

        v = [label.strip() for label in v]
        v = [label[:-1] if label.endswith(".") else label for label in v]

        return [re.sub(r"\\{1,}", "", label.replace('"', "")) for label in v]

candidate_cluster_names: list[str] instance-attribute

validate_candidate_cluster_names(v: list[str]) -> list[str]

Source code in kura/meta_cluster.py
@field_validator("candidate_cluster_names")
def validate_candidate_cluster_names(cls, v: list[str]) -> list[str]:
    if len(v) == 0:
        raise ValueError("Candidate cluster names must be a non-empty list")

    v = [label.strip() for label in v]
    v = [label[:-1] if label.endswith(".") else label for label in v]

    return [re.sub(r"\\{1,}", "", label.replace('"', "")) for label in v]
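To make the normalisation concrete, here is a small example of what the validator does to a raw name: whitespace is stripped, a trailing period is dropped, and quotes and backslashes are removed. The expected output is inferred from the code above.

from kura.meta_cluster import CandidateClusters

candidates = CandidateClusters(
    candidate_cluster_names=['  Help with "Python" homework.  ']
)
print(candidates.candidate_cluster_names)
# Expected: ['Help with Python homework']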

ClusterLabel

Bases: BaseModel

Source code in kura/meta_cluster.py
class ClusterLabel(BaseModel):
    higher_level_cluster: str

    @field_validator("higher_level_cluster")
    def validate_higher_level_cluster(cls, v: str, info: ValidationInfo) -> str:
        candidate_clusters = info.context["candidate_clusters"]  # pyright: ignore

        # Exact match check
        if v in candidate_clusters:
            return v

        # Fuzzy match check with 90% similarity threshold
        for candidate in candidate_clusters:
            similarity = fuzz.ratio(v, candidate)
            if similarity >= 90:  # 90% similarity threshold
                return candidate

        # If no match found
        raise ValueError(
            f"""
            Invalid higher-level cluster: |{v}|

            Valid clusters are:
            {", ".join(f"|{c}|" for c in candidate_clusters)}
            """
        )
        return v

higher_level_cluster: str instance-attribute

validate_higher_level_cluster(v: str, info: ValidationInfo) -> str

Source code in kura/meta_cluster.py
@field_validator("higher_level_cluster")
def validate_higher_level_cluster(cls, v: str, info: ValidationInfo) -> str:
    candidate_clusters = info.context["candidate_clusters"]  # pyright: ignore

    # Exact match check
    if v in candidate_clusters:
        return v

    # Fuzzy match check with 90% similarity threshold
    for candidate in candidate_clusters:
        similarity = fuzz.ratio(v, candidate)
        if similarity >= 90:  # 90% similarity threshold
            return candidate

    # If no match found
    raise ValueError(
        f"""
        Invalid higher-level cluster: |{v}|

        Valid clusters are:
        {", ".join(f"|{c}|" for c in candidate_clusters)}
        """
    )
    return v
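A hedged example of the validation flow, assuming Pydantic v2 semantics: an exact candidate name passes through unchanged, while a near-match (at least 90% fuzz.ratio similarity) is snapped to the candidate it resembles. The candidate list is supplied through the validation context.

from kura.meta_cluster import ClusterLabel

label = ClusterLabel.model_validate(
    {"higher_level_cluster": "Help with high school math homework!"},
    context={"candidate_clusters": ["Help with high school math homework"]},
)
print(label.higher_level_cluster)
# Expected: 'Help with high school math homework' (snapped to the exact candidate)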

MetaClusterModel

Bases: BaseMetaClusterModel

Source code in kura/meta_cluster.py
class MetaClusterModel(BaseMetaClusterModel):
    @property
    def checkpoint_filename(self) -> str:
        """The filename to use for checkpointing this model's output."""
        return "meta_clusters.jsonl"

    def __init__(
        self,
        max_concurrent_requests: int = 50,
        model: str = "openai/gpt-4o-mini",
        embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
        clustering_model: Union[BaseClusteringMethod, None] = None,
        max_clusters: int = 10,
        console: Optional["Console"] = None,
        **kwargs,  # For future use
    ):
        if clustering_model is None:
            from kura.k_means import KmeansClusteringMethod

            clustering_model = KmeansClusteringMethod(12)

        self.max_concurrent_requests = max_concurrent_requests
        self.sem = Semaphore(max_concurrent_requests)
        self.client = instructor.from_provider(model, async_client=True)
        self.console = console
        self.max_clusters = max_clusters

        if embedding_model is None:
            embedding_model = OpenAIEmbeddingModel()

        self.embedding_model = embedding_model
        self.clustering_model = clustering_model
        self.model = model
        self.console = console

        logger.info(
            f"Initialized MetaClusterModel with model={model}, max_concurrent_requests={max_concurrent_requests}, embedding_model={type(embedding_model).__name__}, clustering_model={type(clustering_model).__name__}, max_clusters={max_clusters}"
        )

        # Debug: Check if console is set
        if self.console:
            logger.debug(f"Console is set to {type(self.console)}")
        else:
            logger.debug("Console is None - Rich progress bars will not be available")

    async def _gather_with_progress(
        self,
        tasks,
        desc: str = "Processing",
        disable: bool = False,
        show_preview: bool = False,
    ):
        """Helper method to run async gather with Rich progress bar if available, otherwise tqdm."""
        if self.console and not disable:
            try:
                from rich.progress import (
                    Progress,
                    SpinnerColumn,
                    TextColumn,
                    BarColumn,
                    TaskProgressColumn,
                    TimeRemainingColumn,
                )
                from rich.live import Live
                from rich.layout import Layout
                from rich.panel import Panel
                from rich.text import Text
                from rich.errors import LiveError

                # Check if a Live display is already active by trying to get the current live instance
                try:
                    # Try to access the console's current live instance
                    if (
                        hasattr(self.console, "_live")
                        and self.console._live is not None
                    ):
                        show_preview = (
                            False  # Disable preview if Live is already active
                        )
                except AttributeError:
                    pass  # Console doesn't have _live attribute, that's fine

                if show_preview:
                    # Use Live display with progress and preview buffer
                    layout = Layout()
                    layout.split_column(
                        Layout(name="progress", size=3), Layout(name="preview")
                    )

                    preview_buffer = []
                    max_preview_items = 3

                    # Create progress with cleaner display
                    progress = Progress(
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        TimeRemainingColumn(),
                        console=self.console,
                    )
                    task_id = progress.add_task(f"[cyan]{desc}...", total=len(tasks))
                    layout["progress"].update(progress)

                    try:
                        with Live(layout, console=self.console, refresh_per_second=4):
                            completed_tasks = []
                            for i, task in enumerate(asyncio.as_completed(tasks)):
                                result = await task
                                completed_tasks.append(result)
                                progress.update(task_id, completed=i + 1)

                                # Handle different result types
                                if isinstance(result, list):
                                    # For operations that return lists of clusters
                                    for item in result:
                                        if (
                                            hasattr(item, "name")
                                            and hasattr(item, "description")
                                            and item.parent_id is None
                                        ):
                                            preview_buffer.append(item)
                                            if len(preview_buffer) > max_preview_items:
                                                preview_buffer.pop(0)
                                elif hasattr(result, "name") and hasattr(
                                    result, "description"
                                ):
                                    # For operations that return single clusters
                                    preview_buffer.append(result)
                                    if len(preview_buffer) > max_preview_items:
                                        preview_buffer.pop(0)

                                # Update preview display if we have clusters
                                if preview_buffer:
                                    preview_text = Text()
                                    for j, cluster in enumerate(preview_buffer):
                                        preview_text.append(
                                            "Meta Cluster: ", style="bold magenta"
                                        )
                                        preview_text.append(
                                            f"{cluster.name[:80]}...\n",
                                            style="bold white",
                                        )
                                        preview_text.append(
                                            "Description: ", style="bold cyan"
                                        )
                                        preview_text.append(
                                            f"{cluster.description[:100]}...\n\n",
                                            style="dim white",
                                        )

                                    layout["preview"].update(
                                        Panel(
                                            preview_text,
                                            title=f"[magenta]Recent Meta Clusters ({len(preview_buffer)}/{max_preview_items})",
                                            border_style="magenta",
                                        )
                                    )

                            return completed_tasks
                    except LiveError:
                        # If Rich Live fails (e.g., another Live is active), fall back to simple progress
                        with progress:
                            completed_tasks = []
                            for i, task in enumerate(asyncio.as_completed(tasks)):
                                result = await task
                                completed_tasks.append(result)
                                progress.update(task_id, completed=i + 1)
                            return completed_tasks
                else:
                    # Regular progress bar without preview (or when Live is already active)
                    progress = Progress(
                        SpinnerColumn(),
                        TextColumn("[progress.description]{task.description}"),
                        BarColumn(),
                        TaskProgressColumn(),
                        TimeRemainingColumn(),
                        console=self.console,
                    )

                    with progress:
                        task_id = progress.add_task(
                            f"[cyan]{desc}...", total=len(tasks)
                        )

                        completed_tasks = []
                        for i, task in enumerate(asyncio.as_completed(tasks)):
                            result = await task
                            completed_tasks.append(result)
                            progress.update(task_id, completed=i + 1)

                        return completed_tasks

            except (ImportError, LiveError):  # type: ignore
                # Rich not available or Live error, run silently
                return await asyncio.gather(*tasks)
        else:
            # No console, run silently
            return await asyncio.gather(*tasks)

    async def generate_candidate_clusters(
        self, clusters: list[Cluster], sem: Semaphore
    ) -> list[str]:
        async with sem:
            resp = await self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": """
                You are tasked with creating higher-level cluster names based on a given list of clusters and their descriptions. Your goal is to come up with broader categories that could encompass one or more of the provided clusters

                First, review the list of clusters and their descriptions:
                <cluster_list>
                    {% for cluster in clusters %}
                    <cluster>{{ cluster.name }}: {{ cluster.description }}</cluster>
                    {% endfor %}
                </cluster_list>

                Your task is to create at most {{ desired_number }} higher-level cluster names that could potentially include one or more of the provided clusters. These higher-level clusters should represent broader categories or themes that emerge from the given clusters, while remaining as specific as possible. If there are many clusters with a specific theme, ensure that the higher-level cluster name remains the maximum level of specificity. You are helping to organize user behavior data in order to improve safety, monitoring, and observability. You can generate less than {{ desired_number }} names if you feel that fewer are appropriate and accurately capture the clusters.

                Guidelines for creating higher-level cluster names
                1. Analyze the themes, topics or characteristics common to multiple clusters.
                2. Create names that are specific enough to be meaningful but not so specific that they can't meaningfully represent many different clusters. Avoid overly general or vague terms, and do not hesitate to describe socially harmful or sensitive topics (in fact, clusters that clearly describe harmful behavior are slightly preferred); specificity is necessary for observability and enforcement.
                3. Ensure that the higher-level cluster names are distinct from one another.
                4. Use clear, concise, and descriptive language for the cluster names. Assume neither good nor bad faith for the content in the clusters.

                Think about the relationships between the given clusters and potential overarching themes.

                Focus on creating meaningful, distinct and precise ( but not overly specific ) higher-level cluster names that could encompass multiple sub-clusters.
                """.strip(),
                    },
                ],
                response_model=CandidateClusters,
                context={
                    "clusters": clusters,
                    "desired_number": math.ceil(len(clusters) / 2)
                    if len(clusters)
                    >= 3  # If we have two clusters we just merge them tbh
                    else 1,
                },
                max_retries=3,
            )
            return resp.candidate_cluster_names

    async def label_cluster(self, cluster: Cluster, candidate_clusters: list[str]):
        async with self.sem:
            resp = await self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": """
You are tasked with categorizing a specific cluster into one of the provided higher-level clusters for observability, monitoring, and content moderation. Your goal is to determine which higher-level cluster best fits the given specific cluster based on its name and description.

First, here are the ONLY valid higher-level clusters you may select from:
<higher_level_clusters>
{% for cluster in candidate_clusters %}
<higher_level_cluster>{{ cluster }}</higher_level_cluster>
{% endfor %}
</higher_level_clusters>

Here is the specific cluster to categorize:
<specific_cluster>
Name: {{ cluster.name }}
Description: {{ cluster.description }}
</specific_cluster>

RULES:
1. You MUST select EXACTLY ONE higher-level cluster from the provided list
2. You MUST output the higher-level cluster name EXACTLY as written - no modifications allowed
3. You MUST NOT create new cluster names or combinations
4. You MUST NOT output any additional text or explanations
5. You MUST NOT use partial matches or approximate names

CLASSIFICATION PROCESS:
1. First, record the exact list of valid higher-level clusters
2. Read the specific cluster's name and description carefully
3. Compare the specific cluster's key characteristics against each valid higher-level cluster
4. Select the single most appropriate higher-level cluster that encompasses the specific cluster
5. Verify your selected cluster exactly matches one from the valid list
6. Output ONLY the selected higher-level cluster name, exactly as it appeared in the valid list

Here is the specific cluster to categorize:

<specific_cluster>
Name: {{ cluster.name }}
Description: {{ cluster.description }}
</specific_cluster>

Based on this information, determine the most appropriate higher-level cluster and provide your answer as instructed.
                        """,
                    }
                ],
                response_model=ClusterLabel,
                context={
                    "cluster": cluster,
                    "candidate_clusters": candidate_clusters,
                },
                max_retries=3,
            )
            return {
                "cluster": cluster,
                "label": resp.higher_level_cluster,
            }

    async def rename_cluster_group(self, clusters: list[Cluster]) -> list[Cluster]:
        async with self.sem:
            resp = await self.client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": """
                        You are tasked with summarizing a group of related cluster names into a short, precise, and accurate overall description and name. Your goal is to create a concise summary that captures the essence of these clusters

                        The cluster name should be at most ten words long (perhaps less) and be specific but also reflective of most of the clusters that comprise them. The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job' are good examples.

                        Do not shy away from describing socially harmful or sensitive topics. For instance, "Write fantasy sexual roleplay with octopi and monsters", "Generate blog spam for gambling websites", or "Assist with high school math homework" would be better than "Write erotic content" or "Help with homework". Be as descriptive as possible and assume neither good nor bad faith. Do not hesitate to identify and describe socially harmful or sensitive topics specifically; specificity is necessary for monitoring.

                        Ensure your summary and name accurately represent the clusters and are specific to the clusters.
                        """,
                    },
                    {
                        "role": "user",
                        "content": """
                        Here are the related cluster names
                        <clusters>
                            {% for cluster in clusters %}
                                <cluster>{{ cluster.name }}: {{ cluster.description }}</cluster>
                            {% endfor %}
                        </clusters>
                        """,
                    },
                ],
                context={"clusters": clusters},
                response_model=GeneratedCluster,
            )

            res = []

            new_cluster = Cluster(
                name=resp.name,
                description=resp.summary,
                slug=resp.slug,
                chat_ids=[
                    chat_id for cluster in clusters for chat_id in cluster.chat_ids
                ],
                parent_id=None,
            )

            res.append(new_cluster)

            for cluster in clusters:
                res.append(
                    Cluster(
                        id=cluster.id,
                        name=cluster.name,
                        description=cluster.description,
                        slug=cluster.slug,
                        chat_ids=cluster.chat_ids,
                        parent_id=new_cluster.id,
                    )
                )

            return res

    async def generate_meta_clusters(
        self, clusters: list[Cluster], show_preview: bool = True
    ) -> list[Cluster]:
        # Use a single Live display for the entire meta clustering operation
        if self.console and show_preview:
            try:
                from rich.progress import (
                    Progress,
                    SpinnerColumn,
                    TextColumn,
                    BarColumn,
                    TaskProgressColumn,
                    TimeRemainingColumn,
                )
                from rich.live import Live
                from rich.layout import Layout
                from rich.panel import Panel
                from rich.text import Text
                from rich.errors import LiveError

                # Create layout for the entire meta clustering operation
                layout = Layout()
                layout.split_column(
                    Layout(
                        name="progress", size=6
                    ),  # More space for multiple progress bars
                    Layout(name="preview"),
                )

                # Create progress display
                progress = Progress(
                    SpinnerColumn(),
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TaskProgressColumn(),
                    TimeRemainingColumn(),
                    console=self.console,
                )
                layout["progress"].update(progress)

                preview_buffer = []
                max_preview_items = 3

                try:
                    with Live(layout, console=self.console, refresh_per_second=4):
                        # Step 1: Generate candidate clusters
                        candidate_labels = await self.generate_candidate_clusters(
                            clusters, Semaphore(self.max_concurrent_requests)
                        )

                        # Step 2: Label clusters with progress
                        label_task_id = progress.add_task(
                            "[cyan]Labeling clusters...", total=len(clusters)
                        )
                        cluster_labels = []
                        for i, cluster in enumerate(clusters):
                            result = await self.label_cluster(cluster, candidate_labels)
                            cluster_labels.append(result)
                            progress.update(label_task_id, completed=i + 1)

                        # Group clusters by label
                        label_to_clusters = {}
                        for label in cluster_labels:
                            if label["label"] not in label_to_clusters:
                                label_to_clusters[label["label"]] = []
                            label_to_clusters[label["label"]].append(label["cluster"])

                        # Step 3: Rename cluster groups with progress and preview
                        rename_task_id = progress.add_task(
                            "[cyan]Renaming cluster groups...",
                            total=len(label_to_clusters),
                        )
                        new_clusters = []
                        for i, cluster_group in enumerate(label_to_clusters.values()):
                            result = await self.rename_cluster_group(cluster_group)
                            new_clusters.append(result)
                            progress.update(rename_task_id, completed=i + 1)

                            # Update preview with new meta clusters
                            for cluster in result:
                                if (
                                    hasattr(cluster, "name")
                                    and hasattr(cluster, "description")
                                    and cluster.parent_id is None
                                ):
                                    preview_buffer.append(cluster)
                                    if len(preview_buffer) > max_preview_items:
                                        preview_buffer.pop(0)

                            # Update preview display
                            if preview_buffer:
                                preview_text = Text()
                                for j, cluster in enumerate(preview_buffer):
                                    preview_text.append(
                                        "Meta Cluster: ", style="bold magenta"
                                    )
                                    preview_text.append(
                                        f"{cluster.name[:80]}...\n", style="bold white"
                                    )
                                    preview_text.append(
                                        "Description: ", style="bold cyan"
                                    )
                                    preview_text.append(
                                        f"{cluster.description[:100]}...\n\n",
                                        style="dim white",
                                    )

                                layout["preview"].update(
                                    Panel(
                                        preview_text,
                                        title=f"[magenta]Recent Meta Clusters ({len(preview_buffer)}/{max_preview_items})",
                                        border_style="magenta",
                                    )
                                )

                        # Flatten results
                        res = []
                        for new_cluster in new_clusters:
                            res.extend(new_cluster)

                        return res

                except LiveError:
                    # Fall back to the original method without Live display
                    return await self._generate_meta_clusters_fallback(clusters)

            except ImportError:
                # Rich not available, fall back
                return await self._generate_meta_clusters_fallback(clusters)
        else:
            # No console or preview disabled, use original method
            return await self._generate_meta_clusters_fallback(clusters)

    async def _generate_meta_clusters_fallback(
        self, clusters: list[Cluster]
    ) -> list[Cluster]:
        """Fallback method for generate_meta_clusters when Live display is not available"""
        candidate_labels = await self.generate_candidate_clusters(
            clusters, Semaphore(self.max_concurrent_requests)
        )

        cluster_labels = await self._gather_with_progress(
            [self.label_cluster(cluster, candidate_labels) for cluster in clusters],
            desc="Labeling clusters",
            disable=False,
            show_preview=False,  # Disable preview to avoid nested Live displays
        )

        label_to_clusters = {}
        for label in cluster_labels:
            if label["label"] not in label_to_clusters:
                label_to_clusters[label["label"]] = []

            label_to_clusters[label["label"]].append(label["cluster"])

        new_clusters = await self._gather_with_progress(
            [
                self.rename_cluster_group(cluster)
                for cluster in label_to_clusters.values()
            ],
            desc="Renaming cluster groups",
            show_preview=False,  # Disable preview to avoid nested Live displays
        )

        res = []
        for new_cluster in new_clusters:
            res.extend(new_cluster)

        return res

    async def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
        """
        Takes a list of existing clusters and generates a few more general, higher-order clusters. This represents a single iteration of the meta-clustering process.

        If only a single cluster is provided, a new higher-level cluster with the same name as the original is returned alongside it. (This is an edge case that should eventually be handled more gracefully.)
        """
        if not clusters:
            return []

        if len(clusters) == 1:
            logger.info("Only one cluster, returning it as a meta cluster")
            new_cluster = Cluster(
                name=clusters[0].name,
                description=clusters[0].description,
                slug=clusters[0].slug,
                chat_ids=clusters[0].chat_ids,
                parent_id=None,
            )
            return [new_cluster, clusters[0]]

        texts_to_embed = [str(cluster) for cluster in clusters]

        logger.info(
            f"Embedding {len(texts_to_embed)} clusters for meta-clustering using {type(self.embedding_model).__name__}..."
        )

        cluster_embeddings = await self.embedding_model.embed(texts_to_embed)

        if not cluster_embeddings or len(cluster_embeddings) != len(clusters):
            logger.error(
                "Error: Number of embeddings does not match number of clusters or embeddings are empty for meta-clustering."
            )
            return []

        clusters_and_embeddings = [
            {
                "item": cluster,
                "embedding": embedding,
            }
            for cluster, embedding in zip(clusters, cluster_embeddings)
        ]

        cluster_id_to_clusters: dict[int, list[Cluster]] = (
            self.clustering_model.cluster(clusters_and_embeddings)
        )  # type: ignore

        new_clusters = await self._gather_with_progress(
            [
                self.generate_meta_clusters(
                    cluster_id_to_clusters[cluster_id], show_preview=True
                )
                for cluster_id in cluster_id_to_clusters
            ],
            desc="Generating Meta Clusters",
            show_preview=True,
        )

        res = []
        for new_cluster in new_clusters:
            res.extend(new_cluster)

        return res

checkpoint_filename: str property

The filename to use for checkpointing this model's output.
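
A minimal sketch of how this property is typically consumed (assumes MetaClusterModel is importable from kura.meta_cluster, matching the source path shown below; the concrete filename is defined in the property's source):

from kura.meta_cluster import MetaClusterModel

meta_model = MetaClusterModel()
print(meta_model.checkpoint_filename)  # filename a checkpoint manager uses for this model's cached output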

client = instructor.from_provider(model, async_client=True) instance-attribute

clustering_model = clustering_model instance-attribute

console = console instance-attribute

embedding_model = embedding_model instance-attribute

max_clusters = max_clusters instance-attribute

max_concurrent_requests = max_concurrent_requests instance-attribute

model = model instance-attribute

sem = Semaphore(max_concurrent_requests) instance-attribute

__init__(max_concurrent_requests: int = 50, model: str = 'openai/gpt-4o-mini', embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(), clustering_model: Union[BaseClusteringMethod, None] = None, max_clusters: int = 10, console: Optional['Console'] = None, **kwargs)

Source code in kura/meta_cluster.py
def __init__(
    self,
    max_concurrent_requests: int = 50,
    model: str = "openai/gpt-4o-mini",
    embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
    clustering_model: Union[BaseClusteringMethod, None] = None,
    max_clusters: int = 10,
    console: Optional["Console"] = None,
    **kwargs,  # For future use
):
    if clustering_model is None:
        from kura.k_means import KmeansClusteringMethod

        clustering_model = KmeansClusteringMethod(12)

    self.max_concurrent_requests = max_concurrent_requests
    self.sem = Semaphore(max_concurrent_requests)
    self.client = instructor.from_provider(model, async_client=True)
    self.console = console
    self.max_clusters = max_clusters

    if embedding_model is None:
        embedding_model = OpenAIEmbeddingModel()

    self.embedding_model = embedding_model
    self.clustering_model = clustering_model
    self.model = model
    self.console = console

    logger.info(
        f"Initialized MetaClusterModel with model={model}, max_concurrent_requests={max_concurrent_requests}, embedding_model={type(embedding_model).__name__}, clustering_model={type(clustering_model).__name__}, max_clusters={max_clusters}"
    )

    # Debug: Check if console is set
    if self.console:
        logger.debug(f"Console is set to {type(self.console)}")
    else:
        logger.debug("Console is None - Rich progress bars will not be available")

generate_candidate_clusters(clusters: list[Cluster], sem: Semaphore) -> list[str] async

Source code in kura/meta_cluster.py
async def generate_candidate_clusters(
    self, clusters: list[Cluster], sem: Semaphore
) -> list[str]:
    async with sem:
        resp = await self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": """
            You are tasked with creating higher-level cluster names based on a given list of clusters and their descriptions. Your goal is to come up with broader categories that could encompass one or more of the provided clusters

            First, review the list of clusters and their descriptions:
            <cluster_list>
                {% for cluster in clusters %}
                <cluster>{{ cluster.name }}: {{ cluster.description }}</cluster>
                {% endfor %}
            </cluster_list>

            Your task is to create at most {{ desired_number }} higher-level cluster names that could potentially include one or more of the provided clusters. These higher-level clusters should represent broader categories or themes that emerge from the given clusters, while remaining as specific as possible. If there are many clusters with a specific theme, ensure that the higher-level cluster name remains the maximum level of specificity. You are helping to organize user behavior data in order to improve safety, monitoring, and observability. You can generate less than {{ desired_number }} names if you feel that fewer are appropriate and accurately capture the clusters.

            Guidelines for creating higher-level cluster names
            1. Analyze the themes, topics or characteristics common to multiple clusters.
            2. Create names that are specific enough to be meaningful but not so specific that they can't meaningfully represent many different clusters. Avoid overly general or vague terms, and do not hesitate to describe socially harmful or sensitive topics (in fact, clusters that clearly describe harmful behavior are slightly preferred); specificity is necessary for observability and enforcement.
            3. Ensure that the higher-level cluster names are distinct from one another.
            4. Use clear, concise, and descriptive language for the cluster names. Assume neither good nor bad faith for the content in the clusters.

            Think about the relationships between the given clusters and potential overarching themes.

            Focus on creating meaningful, distinct and precise ( but not overly specific ) higher-level cluster names that could encompass multiple sub-clusters.
            """.strip(),
                },
            ],
            response_model=CandidateClusters,
            context={
                "clusters": clusters,
                "desired_number": math.ceil(len(clusters) / 2)
                if len(clusters)
                >= 3  # If we have two clusters we just merge them tbh
                else 1,
            },
            max_retries=3,
        )
        return resp.candidate_cluster_names
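
A usage sketch, assuming an initialized MetaClusterModel named meta_model and a list of Cluster objects named clusters:

from asyncio import Semaphore

candidate_names = await meta_model.generate_candidate_clusters(
    clusters, Semaphore(meta_model.max_concurrent_requests)
)
# candidate_names is a list[str] of proposed higher-level cluster names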

generate_meta_clusters(clusters: list[Cluster], show_preview: bool = True) -> list[Cluster] async

Source code in kura/meta_cluster.py
async def generate_meta_clusters(
    self, clusters: list[Cluster], show_preview: bool = True
) -> list[Cluster]:
    # Use a single Live display for the entire meta clustering operation
    if self.console and show_preview:
        try:
            from rich.progress import (
                Progress,
                SpinnerColumn,
                TextColumn,
                BarColumn,
                TaskProgressColumn,
                TimeRemainingColumn,
            )
            from rich.live import Live
            from rich.layout import Layout
            from rich.panel import Panel
            from rich.text import Text
            from rich.errors import LiveError

            # Create layout for the entire meta clustering operation
            layout = Layout()
            layout.split_column(
                Layout(
                    name="progress", size=6
                ),  # More space for multiple progress bars
                Layout(name="preview"),
            )

            # Create progress display
            progress = Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeRemainingColumn(),
                console=self.console,
            )
            layout["progress"].update(progress)

            preview_buffer = []
            max_preview_items = 3

            try:
                with Live(layout, console=self.console, refresh_per_second=4):
                    # Step 1: Generate candidate clusters
                    candidate_labels = await self.generate_candidate_clusters(
                        clusters, Semaphore(self.max_concurrent_requests)
                    )

                    # Step 2: Label clusters with progress
                    label_task_id = progress.add_task(
                        "[cyan]Labeling clusters...", total=len(clusters)
                    )
                    cluster_labels = []
                    for i, cluster in enumerate(clusters):
                        result = await self.label_cluster(cluster, candidate_labels)
                        cluster_labels.append(result)
                        progress.update(label_task_id, completed=i + 1)

                    # Group clusters by label
                    label_to_clusters = {}
                    for label in cluster_labels:
                        if label["label"] not in label_to_clusters:
                            label_to_clusters[label["label"]] = []
                        label_to_clusters[label["label"]].append(label["cluster"])

                    # Step 3: Rename cluster groups with progress and preview
                    rename_task_id = progress.add_task(
                        "[cyan]Renaming cluster groups...",
                        total=len(label_to_clusters),
                    )
                    new_clusters = []
                    for i, cluster_group in enumerate(label_to_clusters.values()):
                        result = await self.rename_cluster_group(cluster_group)
                        new_clusters.append(result)
                        progress.update(rename_task_id, completed=i + 1)

                        # Update preview with new meta clusters
                        for cluster in result:
                            if (
                                hasattr(cluster, "name")
                                and hasattr(cluster, "description")
                                and cluster.parent_id is None
                            ):
                                preview_buffer.append(cluster)
                                if len(preview_buffer) > max_preview_items:
                                    preview_buffer.pop(0)

                        # Update preview display
                        if preview_buffer:
                            preview_text = Text()
                            for j, cluster in enumerate(preview_buffer):
                                preview_text.append(
                                    "Meta Cluster: ", style="bold magenta"
                                )
                                preview_text.append(
                                    f"{cluster.name[:80]}...\n", style="bold white"
                                )
                                preview_text.append(
                                    "Description: ", style="bold cyan"
                                )
                                preview_text.append(
                                    f"{cluster.description[:100]}...\n\n",
                                    style="dim white",
                                )

                            layout["preview"].update(
                                Panel(
                                    preview_text,
                                    title=f"[magenta]Recent Meta Clusters ({len(preview_buffer)}/{max_preview_items})",
                                    border_style="magenta",
                                )
                            )

                    # Flatten results
                    res = []
                    for new_cluster in new_clusters:
                        res.extend(new_cluster)

                    return res

            except LiveError:
                # Fall back to the original method without Live display
                return await self._generate_meta_clusters_fallback(clusters)

        except ImportError:
            # Rich not available, fall back
            return await self._generate_meta_clusters_fallback(clusters)
    else:
        # No console or preview disabled, use original method
        return await self._generate_meta_clusters_fallback(clusters)
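
Usage sketch (same assumed meta_model and clusters as above); passing show_preview=False skips the Rich Live display and takes the fallback path:

meta = await meta_model.generate_meta_clusters(clusters, show_preview=False)
# The result mixes newly generated parent clusters (parent_id is None)
# with the input clusters re-parented under them.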

label_cluster(cluster: Cluster, candidate_clusters: list[str]) async

Source code in kura/meta_cluster.py
    async def label_cluster(self, cluster: Cluster, candidate_clusters: list[str]):
        async with self.sem:
            resp = await self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": """
You are tasked with categorizing a specific cluster into one of the provided higher-level clusters for observability, monitoring, and content moderation. Your goal is to determine which higher-level cluster best fits the given specific cluster based on its name and description.

First, here are the ONLY valid higher-level clusters you may select from:
<higher_level_clusters>
{% for cluster in candidate_clusters %}
<higher_level_cluster>{{ cluster }}</higher_level_cluster>
{% endfor %}
</higher_level_clusters>

Here is the specific cluster to categorize:
<specific_cluster>
Name: {{ cluster.name }}
Description: {{ cluster.description }}
</specific_cluster>

RULES:
1. You MUST select EXACTLY ONE higher-level cluster from the provided list
2. You MUST output the higher-level cluster name EXACTLY as written - no modifications allowed
3. You MUST NOT create new cluster names or combinations
4. You MUST NOT output any additional text or explanations
5. You MUST NOT use partial matches or approximate names

CLASSIFICATION PROCESS:
1. First, record the exact list of valid higher-level clusters
2. Read the specific cluster's name and description carefully
3. Compare the specific cluster's key characteristics against each valid higher-level cluster
4. Select the single most appropriate higher-level cluster that encompasses the specific cluster
5. Verify your selected cluster exactly matches one from the valid list
6. Output ONLY the selected higher-level cluster name, exactly as it appeared in the valid list

Here is the specific cluster to categorize:

<specific_cluster>
Name: {{ cluster.name }}
Description: {{ cluster.description }}
</specific_cluster>

Based on this information, determine the most appropriate higher-level cluster and provide your answer as instructed.
                        """,
                    }
                ],
                response_model=ClusterLabel,
                context={
                    "cluster": cluster,
                    "candidate_clusters": candidate_clusters,
                },
                max_retries=3,
            )
            return {
                "cluster": cluster,
                "label": resp.higher_level_cluster,
            }
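
Usage sketch showing the shape of the returned mapping (candidate_names as produced by generate_candidate_clusters above):

result = await meta_model.label_cluster(clusters[0], candidate_names)
result["cluster"]  # the original Cluster object
result["label"]    # the selected higher-level cluster name (a string)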

reduce_clusters(clusters: list[Cluster]) -> list[Cluster] async

Takes a list of existing clusters and generates a few more general, higher-order clusters. This represents a single iteration of the meta-clustering process.

If only a single cluster is provided, a new higher-level cluster with the same name as the original is returned alongside it. (This is an edge case that should eventually be handled more gracefully.)
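
Example (a hedged sketch, assuming an initialized MetaClusterModel named meta_model and a list of Cluster objects named clusters):

reduced = await meta_model.reduce_clusters(clusters)
parents = [c for c in reduced if c.parent_id is None]       # new higher-level clusters
children = [c for c in reduced if c.parent_id is not None]  # originals, re-parented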

Source code in kura/meta_cluster.py
async def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
    """
    Takes a list of existing clusters and generates a few more general, higher-order clusters. This represents a single iteration of the meta-clustering process.

    If only a single cluster is provided, a new higher-level cluster with the same name as the original is returned alongside it. (This is an edge case that should eventually be handled more gracefully.)
    """
    if not clusters:
        return []

    if len(clusters) == 1:
        logger.info("Only one cluster, returning it as a meta cluster")
        new_cluster = Cluster(
            name=clusters[0].name,
            description=clusters[0].description,
            slug=clusters[0].slug,
            chat_ids=clusters[0].chat_ids,
            parent_id=None,
        )
        return [new_cluster, clusters[0]]

    texts_to_embed = [str(cluster) for cluster in clusters]

    logger.info(
        f"Embedding {len(texts_to_embed)} clusters for meta-clustering using {type(self.embedding_model).__name__}..."
    )

    cluster_embeddings = await self.embedding_model.embed(texts_to_embed)

    if not cluster_embeddings or len(cluster_embeddings) != len(clusters):
        logger.error(
            "Error: Number of embeddings does not match number of clusters or embeddings are empty for meta-clustering."
        )
        return []

    clusters_and_embeddings = [
        {
            "item": cluster,
            "embedding": embedding,
        }
        for cluster, embedding in zip(clusters, cluster_embeddings)
    ]

    cluster_id_to_clusters: dict[int, list[Cluster]] = (
        self.clustering_model.cluster(clusters_and_embeddings)
    )  # type: ignore

    new_clusters = await self._gather_with_progress(
        [
            self.generate_meta_clusters(
                cluster_id_to_clusters[cluster_id], show_preview=True
            )
            for cluster_id in cluster_id_to_clusters
        ],
        desc="Generating Meta Clusters",
        show_preview=True,
    )

    res = []
    for new_cluster in new_clusters:
        res.extend(new_cluster)

    return res

rename_cluster_group(clusters: list[Cluster]) -> list[Cluster] async

Source code in kura/meta_cluster.py
async def rename_cluster_group(self, clusters: list[Cluster]) -> list[Cluster]:
    async with self.sem:
        resp = await self.client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": """
                    You are tasked with summarizing a group of related cluster names into a short, precise, and accurate overall description and name. Your goal is to create a concise summary that captures the essence of these clusters.

                    The cluster name should be at most ten words long (perhaps fewer) and be specific but also reflective of most of the clusters that comprise them. The cluster name should be a sentence in the imperative that captures the user's request. For example, 'Brainstorm ideas for a birthday party' or 'Help me find a new job' are good examples.

                    Do not shy away from describing socially harmful or sensitive topics. For instance, "Write fantasy sexual roleplay with octopi and monsters", "Generate blog spam for gambling websites", or "Assist with high school math homework" would be better than "Write erotic content" or "Help with homework". Be as descriptive as possible and assume neither good nor bad faith. Do not hesitate to identify and describe socially harmful or sensitive topics specifically; specificity is necessary for monitoring.

                    Ensure your summary and name accurately represent the clusters and are specific to the clusters.
                    """,
                },
                {
                    "role": "user",
                    "content": """
                    Here are the related cluster names
                    <clusters>
                        {% for cluster in clusters %}
                            <cluster>{{ cluster.name }}: {{ cluster.description }}</cluster>
                        {% endfor %}
                    </clusters>
                    """,
                },
            ],
            context={"clusters": clusters},
            response_model=GeneratedCluster,
        )

        res = []

        new_cluster = Cluster(
            name=resp.name,
            description=resp.summary,
            slug=resp.slug,
            chat_ids=[
                chat_id for cluster in clusters for chat_id in cluster.chat_ids
            ],
            parent_id=None,
        )

        res.append(new_cluster)

        for cluster in clusters:
            res.append(
                Cluster(
                    id=cluster.id,
                    name=cluster.name,
                    description=cluster.description,
                    slug=cluster.slug,
                    chat_ids=cluster.chat_ids,
                    parent_id=new_cluster.id,
                )
            )

        return res
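
Usage sketch for a single group of related Cluster objects named cluster_group (meta_model as above):

renamed = await meta_model.rename_cluster_group(cluster_group)
renamed[0]   # the newly generated parent Cluster (parent_id=None)
renamed[1:]  # the input clusters, re-parented under renamed[0].id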

Dimensionality Reduction

kura.dimensionality

logger = logging.getLogger(__name__) module-attribute

HDBUMAP

Bases: BaseDimensionalityReduction

Source code in kura/dimensionality.py
class HDBUMAP(BaseDimensionalityReduction):
    @property
    def checkpoint_filename(self) -> str:
        """The filename to use for checkpointing this model's output."""
        return "dimensionality.jsonl"

    def __init__(
        self,
        embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
        n_components: int = 2,
        min_dist: float = 0.1,
        metric: str = "cosine",
        n_neighbors: Union[int, None] = None,
    ):
        self.embedding_model = embedding_model
        self.n_components = n_components
        self.min_dist = min_dist
        self.metric = metric
        self.n_neighbors = n_neighbors
        logger.info(
            f"Initialized HDBUMAP with embedding_model={type(embedding_model).__name__}, n_components={n_components}, min_dist={min_dist}, metric={metric}, n_neighbors={n_neighbors}"
        )

    async def reduce_dimensionality(
        self, clusters: list[Cluster]
    ) -> list[ProjectedCluster]:
        # Embed all clusters
        from umap import UMAP

        if not clusters:
            logger.warning("Empty clusters list provided to reduce_dimensionality")
            return []

        logger.info(f"Starting dimensionality reduction for {len(clusters)} clusters")
        texts_to_embed = [str(c) for c in clusters]

        try:
            cluster_embeddings = await self.embedding_model.embed(texts_to_embed)
            logger.debug(f"Generated embeddings for {len(clusters)} clusters")
        except Exception as e:
            logger.error(f"Failed to generate embeddings for clusters: {e}")
            raise

        if not cluster_embeddings or len(cluster_embeddings) != len(texts_to_embed):
            logger.error(
                f"Error: Number of embeddings ({len(cluster_embeddings) if cluster_embeddings else 0}) does not match number of clusters ({len(texts_to_embed)}) or embeddings are empty."
            )
            return []

        embeddings = np.array(cluster_embeddings)
        logger.debug(f"Created embedding matrix of shape {embeddings.shape}")

        # Project to 2D using UMAP
        n_neighbors_actual = (
            self.n_neighbors if self.n_neighbors else min(15, len(embeddings) - 1)
        )
        logger.debug(
            f"Using UMAP with n_neighbors={n_neighbors_actual}, min_dist={self.min_dist}, metric={self.metric}"
        )

        try:
            umap_reducer = UMAP(
                n_components=self.n_components,
                n_neighbors=n_neighbors_actual,
                min_dist=self.min_dist,
                metric=self.metric,
            )
            reduced_embeddings = umap_reducer.fit_transform(embeddings)
            logger.info(
                f"UMAP dimensionality reduction completed: {embeddings.shape} -> {reduced_embeddings.shape}"  # type: ignore
            )
        except Exception as e:
            logger.error(f"UMAP dimensionality reduction failed: {e}")
            raise

        # Create projected clusters with 2D coordinates
        res = []
        for i, cluster in enumerate(clusters):
            projected = ProjectedCluster(
                slug=cluster.slug,
                id=cluster.id,
                name=cluster.name,
                description=cluster.description,
                chat_ids=cluster.chat_ids,
                parent_id=cluster.parent_id,
                x_coord=float(reduced_embeddings[i][0]),  # pyright: ignore
                y_coord=float(reduced_embeddings[i][1]),  # pyright: ignore
                level=0
                if cluster.parent_id is None
                else 1,  # TODO: Fix this, should reflect the level of the cluster
            )
            res.append(projected)

        logger.info(f"Successfully created {len(res)} projected clusters")
        return res

checkpoint_filename: str property

The filename to use for checkpointing this model's output.
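
Per the property's source above, this resolves to the fixed checkpoint filename:

HDBUMAP().checkpoint_filename  # "dimensionality.jsonl"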

embedding_model = embedding_model instance-attribute

metric = metric instance-attribute

min_dist = min_dist instance-attribute

n_components = n_components instance-attribute

n_neighbors = n_neighbors instance-attribute

__init__(embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(), n_components: int = 2, min_dist: float = 0.1, metric: str = 'cosine', n_neighbors: Union[int, None] = None)

Source code in kura/dimensionality.py
def __init__(
    self,
    embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
    n_components: int = 2,
    min_dist: float = 0.1,
    metric: str = "cosine",
    n_neighbors: Union[int, None] = None,
):
    self.embedding_model = embedding_model
    self.n_components = n_components
    self.min_dist = min_dist
    self.metric = metric
    self.n_neighbors = n_neighbors
    logger.info(
        f"Initialized HDBUMAP with embedding_model={type(embedding_model).__name__}, n_components={n_components}, min_dist={min_dist}, metric={metric}, n_neighbors={n_neighbors}"
    )
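
A hedged construction example based on the signature above (HDBUMAP and OpenAIEmbeddingModel are assumed importable from kura.dimensionality and kura.embedding, respectively):

from kura.dimensionality import HDBUMAP
from kura.embedding import OpenAIEmbeddingModel

reducer = HDBUMAP(
    embedding_model=OpenAIEmbeddingModel(),
    n_components=2,    # project down to 2D coordinates
    min_dist=0.1,
    metric="cosine",
    n_neighbors=None,  # resolved to min(15, len(embeddings) - 1) at runtime
)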

reduce_dimensionality(clusters: list[Cluster]) -> list[ProjectedCluster] async

Source code in kura/dimensionality.py
async def reduce_dimensionality(
    self, clusters: list[Cluster]
) -> list[ProjectedCluster]:
    # Embed all clusters
    from umap import UMAP

    if not clusters:
        logger.warning("Empty clusters list provided to reduce_dimensionality")
        return []

    logger.info(f"Starting dimensionality reduction for {len(clusters)} clusters")
    texts_to_embed = [str(c) for c in clusters]

    try:
        cluster_embeddings = await self.embedding_model.embed(texts_to_embed)
        logger.debug(f"Generated embeddings for {len(clusters)} clusters")
    except Exception as e:
        logger.error(f"Failed to generate embeddings for clusters: {e}")
        raise

    if not cluster_embeddings or len(cluster_embeddings) != len(texts_to_embed):
        logger.error(
            f"Error: Number of embeddings ({len(cluster_embeddings) if cluster_embeddings else 0}) does not match number of clusters ({len(texts_to_embed)}) or embeddings are empty."
        )
        return []

    embeddings = np.array(cluster_embeddings)
    logger.debug(f"Created embedding matrix of shape {embeddings.shape}")

    # Project to 2D using UMAP
    n_neighbors_actual = (
        self.n_neighbors if self.n_neighbors else min(15, len(embeddings) - 1)
    )
    logger.debug(
        f"Using UMAP with n_neighbors={n_neighbors_actual}, min_dist={self.min_dist}, metric={self.metric}"
    )

    try:
        umap_reducer = UMAP(
            n_components=self.n_components,
            n_neighbors=n_neighbors_actual,
            min_dist=self.min_dist,
            metric=self.metric,
        )
        reduced_embeddings = umap_reducer.fit_transform(embeddings)
        logger.info(
            f"UMAP dimensionality reduction completed: {embeddings.shape} -> {reduced_embeddings.shape}"  # type: ignore
        )
    except Exception as e:
        logger.error(f"UMAP dimensionality reduction failed: {e}")
        raise

    # Create projected clusters with 2D coordinates
    res = []
    for i, cluster in enumerate(clusters):
        projected = ProjectedCluster(
            slug=cluster.slug,
            id=cluster.id,
            name=cluster.name,
            description=cluster.description,
            chat_ids=cluster.chat_ids,
            parent_id=cluster.parent_id,
            x_coord=float(reduced_embeddings[i][0]),  # pyright: ignore
            y_coord=float(reduced_embeddings[i][1]),  # pyright: ignore
            level=0
            if cluster.parent_id is None
            else 1,  # TODO: Fix this, should reflect the level of the cluster
        )
        res.append(projected)

    logger.info(f"Successfully created {len(res)} projected clusters")
    return res
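
A usage sketch, assuming the reducer constructed above and a list of Cluster objects named clusters:

projected = await reducer.reduce_dimensionality(clusters)
for p in projected:
    # each ProjectedCluster carries 2D coordinates alongside the original cluster metadata
    print(p.name, p.x_coord, p.y_coord, p.level)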