Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ profile = "black"

[project]
name = "turftopic"
version = "0.26.0"
version = "0.26.1"
description = "Topic modeling with contextual representations from sentence transformers."
authors = [
{ name = "Márton Kardos <power.up1163@gmail.com>", email = "martonkardos@cas.au.dk" }
Expand Down
4 changes: 2 additions & 2 deletions turftopic/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def bin_timestamps(
# Have to substract one, else it starts from one
return np.digitize(unix_timestamps, unix_bins) - 1, bins
else:
# Adding one day, so that the maximum value is still included.
max_timestamp = max(timestamps) + timedelta(days=1)
# Adding one microsecond, so that the maximum value is still included.
max_timestamp = max(timestamps) + timedelta(microseconds=1)
unix_bins = np.histogram_bin_edges(unix_timestamps, bins=bins)
unix_bins[-1] = max_timestamp.timestamp()
bins = [datetime.fromtimestamp(ts) for ts in unix_bins]
Expand Down
5 changes: 3 additions & 2 deletions turftopic/models/_snmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,14 @@ def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray):
F = update_F(X_t.T, G_t, F=None)
return F.T

def transform(self, X: np.ndarray):
def transform(self, X: np.ndarray, F=None):
G = init_G(
X.T,
n_components=self.n_components,
random_state=self.random_state,
)
F = self.components_.T
if F is None:
F = self.components_.T
update = jit(lambda G: update_G(X.T, G, F, sparsity=self.sparsity))
error_at_init = rec_err(X.T, F, G)
prev_error = error_at_init
Expand Down
122 changes: 91 additions & 31 deletions turftopic/models/senstopic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timedelta
from functools import partial
from typing import Literal, Optional, Union

Expand Down Expand Up @@ -217,20 +217,46 @@ def update_vocabulary(self, raw_documents):
set(new_vectorizer.get_feature_names_out()) - set(old_vocab)
)
if len(new_vocab) == 0:
return
return []
new_vocab_embeddings = self.encode_documents(new_vocab)
self.vocab_embeddings = np.concatenate(
[self.vocab_embeddings, new_vocab_embeddings], axis=0
)
self.vectorizer.get_feature_names_out = lambda: np.array(
list(old_vocab) + new_vocab
)
return new_vocab

def partial_fit(
self, raw_documents, y=None, embeddings=None, n_new_components="auto"
self,
raw_documents,
y=None,
embeddings=None,
timestamps=None,
n_new_components="auto",
):
if timestamps is not None:
if (getattr(self, "components_", None) is None) or (
getattr(self, "time_bin_edges", None) is None
):
return self.fit_transform_dynamic(
raw_documents,
embeddings=embeddings,
timestamps=timestamps,
bins=1,
)
if getattr(self, "components_", None) is None:
return self.fit(raw_documents, embeddings=embeddings)
if timestamps is None:
return self.fit(raw_documents, embeddings=embeddings)
if timestamps is not None:
last_edge = self.time_bin_edges[-1]
is_before = [(ts <= last_edge) for ts in timestamps]
n_before = np.sum(is_before)
if n_before:
raise ValueError(
"When using partial fitting on a dynamic model, all new documents have to be in a new time slice. "
f"Currently there are {n_before} documents from before {last_edge}. Remove these before fitting."
)
console = Console()
with console.status("Updating model with new data") as status:
if embeddings is None:
Expand All @@ -253,10 +279,11 @@ def partial_fit(
)
self.n_components_ = self.decomposition.n_components
doc_topic = self.decomposition.transform(embeddings)
console.log("Updated model")
console.log(f"Updated model with {n_new_components} topics.")
status.update("Updating vocabulary")
self.update_vocabulary(raw_documents)
console.log("Updated vocabulary")
new_vocab = self.update_vocabulary(raw_documents)
n_new_vocab = len(new_vocab)
console.log(f"Updated vocabulary with {n_new_vocab} items.")
status.update("Estimating term importances")
vocab_topic = self.decomposition.transform(self.vocab_embeddings)
self.axial_components_ = vocab_topic.T
Expand All @@ -279,13 +306,41 @@ def partial_fit(
*self.topic_names[-n_new_components:],
]
console.log("Updated term importances")
self.top_documents.extend(
self.get_top_documents(
raw_documents,
document_topic_matrix=doc_topic[:, -n_new_components:],
for new_dt in doc_topic[:, -n_new_components:].T:
top = np.argsort(-new_dt)
self.top_documents.append(
[raw_documents[i_top] for i_top in top]
)
)
self.document_topic_matrix = doc_topic
if timestamps is not None:
status.update("Updating temporal components.")
self.time_bin_edges.append(
max(timestamps) + timedelta(microseconds=1)
)
t_components = []
t_importance = []
for t_component, t_imp in zip(
self.axial_temporal_components_, self.temporal_importance_
):
t_component = np.pad(
t_component,
[(0, n_new_components), (0, n_new_vocab)],
mode="constant",
constant_values=0,
)
t_imp = np.pad(
t_imp,
(0, n_new_components),
mode="constant",
constant_values=0,
)
t_components.append(t_component)
t_importance.append(t_imp)
new_imp, new_comp = self._fit_timebin(embeddings, doc_topic)
t_components.append(new_comp)
t_importance.append(new_imp)
self.axial_temporal_components_ = np.stack(t_components)
self.temporal_importance_ = np.stack(t_importance)
self.estimate_components(self.feature_importance)
console.log("Model update done.")
return self

Expand Down Expand Up @@ -373,16 +428,28 @@ def fit_transform_multimodal(
console.log("Images transformed")
return doc_topic

def _fit_timebin(self, t_X, t_dt):
t_imp = t_dt.mean(axis=0)
t_F = self.decomposition.fit_timeslice(t_X, t_dt).T
t_G = self.decomposition.transform(self.vocab_embeddings, F=t_F)
t_components_ = t_G.T
return t_imp, t_components_

def fit_transform_dynamic(
self,
raw_documents,
timestamps: list[datetime],
embeddings: Optional[np.ndarray] = None,
bins: Union[int, list[datetime]] = 10,
) -> np.ndarray:
document_topic_matrix = self.fit_transform(
raw_documents, embeddings=embeddings
)
if getattr(self, "components_", None) is None:
document_topic_matrix = self.fit_transform(
raw_documents, embeddings=embeddings
)
else:
document_topic_matrix = self.transform(
raw_documents, embeddings=embeddings
)
time_labels, self.time_bin_edges = self.bin_timestamps(
timestamps, bins
)
Expand All @@ -394,22 +461,15 @@ def fit_transform_dynamic(
dtype=self.components_.dtype,
)
self.temporal_importance_ = np.zeros((n_bins, n_comp))
# doc_topic = np.dot(X, self.components_.T)
for i_timebin in np.unique(time_labels):
topic_importances = document_topic_matrix[
time_labels == i_timebin
].mean(axis=0)
self.temporal_importance_[i_timebin, :] = topic_importances
t_doc_topic = document_topic_matrix[time_labels == i_timebin]
t_embeddings = self.embeddings[time_labels == i_timebin]
t_components = self.decomposition.fit_timeslice(
t_embeddings, t_doc_topic
)
ax_t = np.maximum(
self.vocab_embeddings @ np.linalg.pinv(t_components), 0
)
self.axial_temporal_components_[i_timebin, :, :] = ax_t.T
self.estimate_components(self.feature_importance)
t_dt = document_topic_matrix[time_labels == i_timebin]
t_X = self.embeddings[time_labels == i_timebin]
t_imp, t_comp = self._fit_timebin(t_X, t_dt)
self.temporal_importance_[i_timebin, :] = t_imp
self.axial_temporal_components_[i_timebin, :, :] = t_comp
self.estimate_components(
self.feature_importance,
)
return document_topic_matrix

@property
Expand Down
Loading