import logging
import uuid
from collections.abc import Generator, Iterable, Sequence
from datetime import datetime
from typing import Any, cast
from django.core.exceptions import FieldError, ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import FileField, Manager, Model
from django.db.models.query import QuerySet
from django.utils import timezone
from kafka import KafkaProducer # type: ignore
from kafka.errors import KafkaTimeoutError, NoBrokersAvailable # type: ignore
from .constants import TYPE_DELETE, TYPE_ENUMERATE, TYPE_EOS, TYPE_REFRESH
from .context import _context
from .settings import get_setting
from .types import (
Message,
MessageContext,
MessageMeta,
MessageSerializer,
ObjectID,
Partitioner,
PartitionKeySerializer,
RefreshFinalizeType,
)
log = logging.getLogger(__name__)
class Batch:
"""
Represents batch operation
"""
def __init__(
self,
objects: Iterable[Model] | None = None,
queryset: QuerySet | None = None,
manager: Manager | None = None,
objects_ids: Sequence[ObjectID] | None = None,
select_related: Sequence[str] | None = None,
prefetch_related: Sequence[str] | None = None,
**kwargs: Any,
):
self.objects = objects
self.queryset = queryset
self.manager = manager
self.objects_ids = objects_ids
self.select_related = select_related
self.prefetch_related = prefetch_related
assert (
self.manager is not None
or self.queryset is not None
or self.objects is not None
)
self.model: type[Model] | None = None
if self.manager is not None:
self.model = self.manager.model
elif self.queryset is not None:
self.model = self.queryset.model
def get_objects(self) -> QuerySet | Iterable[Model]:
queryset = self.queryset
if queryset is None and self.manager is not None:
queryset = self.manager.all()
if queryset is not None and self.objects_ids is not None:
queryset = queryset.filter(pk__in=self.objects_ids).order_by()
if self.select_related:
queryset = queryset.select_related(*self.select_related)
if self.prefetch_related:
queryset = queryset.prefetch_related(*self.prefetch_related)
return queryset
if self.objects is not None:
return self.objects
raise Exception("Invalid batch")
[docs]
class Streamer:
"""
This class encapsulates all streaming logic related to a particular Django
model class
"""
topic: str | None = None
"Kafka topic to stream data."
exclude: Sequence[str] | None = None
"Data fields to exclude."
include: Sequence[str] | None = None
"List of extra (related, computed) fields to include."
static_fields: dict[str, Any] | None = None
"Static data to include in every message."
select_related: Sequence[str] | None = None
"List of related fields to select in queryset."
prefetch_related: Sequence[str] | None = None
"List of related fields to prefetch in queryset."
handle_related: Sequence[str] | None = None
"List of related fields to handle changes."
batch_class: type[Batch] = Batch
"Batch class."
refresh_finalize_type: RefreshFinalizeType = RefreshFinalizeType.ENUMERATE
"Which message type to use at the end when doing a full refresh \
(enumerate or EOS)."
batch_size: int | None = None
"Number of records in batch."
message_serializer: MessageSerializer | None = None
"Serializer function for message serialization. \
See `KafkaProducer documentation`_ for details."
partition_key_serializer: PartitionKeySerializer | None = None
"Partition key serializer function. See `KafkaProducer documentation`_ for details."
partitioner: Partitioner | None = None
"Partitioner function. See `KafkaProducer documentation`_ for details."
id_field: str = "id"
"Field name of object ID."
enumerate_ids_field: str = "ids"
"Field name for list of object IDs in enumerate message."
enumerate_chunk_field: str = "chunk"
"Field name for chunk in enumerate message."
enumerate_chunk_size: int = 5000
"Chunk size in enumerate message."
def __init__(self, **kwargs: Any):
"""
Streamer constructor.
"""
for key, value in kwargs.items():
if value is not None:
setattr(self, key, value)
if not self.topic:
raise ImproperlyConfigured("No streamer topic specified")
self.batch_size = self.batch_size or get_setting("BATCH_SIZE")
if self.message_serializer is None:
self.message_serializer = get_setting(
"DEFAULT_MESSAGE_SERIALIZER", resolve=True
)
if self.partition_key_serializer is None:
self.partition_key_serializer = get_setting(
"DEFAULT_PARTITION_KEY_SERIALIZER", resolve=True
)
if self.partitioner is None:
self.partitioner = get_setting("DEFAULT_PARTITIONER", resolve=True)
def get_data_for_object(self, obj: Model, batch: Batch) -> dict[str, Any]:
"""
Returns data fields for given object
"""
def get_concrete_fields(
obj: Model,
batch: Batch,
related_name: str | None = None,
exclude: Sequence[str] | None = None,
) -> dict[str, Any]:
if exclude and related_name:
exclude = [
f[len(related_name) + 1 :]
for f in exclude
if f.startswith(related_name + ".")
]
data = {}
for f in obj._meta.concrete_fields: # type: ignore
if exclude and f.name in exclude:
continue
if isinstance(f, FileField):
continue
if related_name:
method_name = "load_%s__%s" % (related_name, f.attname)
else:
method_name = "load_%s" % f.attname
func = getattr(self, method_name, None)
if func is not None:
value = func(obj, batch)
else:
value = getattr(obj, f.attname)
data[f.attname] = value
return data
data = get_concrete_fields(obj, batch, exclude=self.exclude)
if self.include:
for name in self.include:
method_name = "load_%s" % name
func = getattr(self, method_name, None)
try:
if func is not None:
value = func(obj, batch)
else:
value = getattr(obj, name)
except ObjectDoesNotExist:
value = None
if isinstance(value, Manager):
value = value.all()
if isinstance(value, (QuerySet, list, tuple)):
value_list = []
for sub_value in value:
if isinstance(sub_value, Model):
value_list.append(
get_concrete_fields(
sub_value,
batch,
related_name=name,
exclude=self.exclude,
)
)
else:
value_list.append(sub_value)
value = value_list
elif isinstance(value, Model):
value = get_concrete_fields(
value, batch, related_name=name, exclude=self.exclude
)
data[name] = value
return data
def get_id(self, obj: Model, batch: Batch) -> ObjectID:
if obj.pk is not None:
assert isinstance(obj.pk, ObjectID)
return obj.pk
obj_id = getattr(obj, "_kafkastreamer_pre_delete_pk")
assert isinstance(obj_id, ObjectID)
return obj_id
def get_message(
self,
obj: Model,
batch: Batch,
msg_type: str | None = None,
timestamp: datetime | None = None,
) -> Message:
"""
Returns Message tuple for given obj and message type
"""
if msg_type is None:
msg_type = TYPE_REFRESH
if timestamp is None:
timestamp = timezone.now()
meta = MessageMeta(
timestamp=timestamp,
msg_type=msg_type,
context=self.get_context_info(),
)
data = self.get_data_for_object(obj, batch)
extra = self.get_extra_data(obj, batch)
if extra:
data.update(extra)
obj_id = self.get_id(obj, batch)
msg = Message(meta=meta, obj_id=obj_id, data=data)
return msg
def get_delete_message(
self,
obj_id: ObjectID,
timestamp: datetime,
obj: Model | None = None,
batch: Batch | None = None,
) -> Message:
"""
Returns Message tuple for delete message type for given object ID
"""
meta = MessageMeta(
timestamp=timestamp,
msg_type=TYPE_DELETE,
context=self.get_context_info(),
)
data = {
self.id_field: obj_id,
}
extra = self.get_extra_data(obj, batch)
if extra:
data.update(extra)
msg = Message(meta=meta, obj_id=obj_id, data=data)
return msg
def get_enumerate_message(
self,
objects_ids: Sequence[ObjectID],
timestamp: datetime,
batch: Batch | None = None,
chunk_index: int | None = None,
chunk_total: int | None = None,
chunk_session: str | None = None,
) -> Message:
"""
Returns Message tuple for enumerate message type for given objects IDs
"""
meta = MessageMeta(
timestamp=timestamp,
msg_type=TYPE_ENUMERATE,
context=self.get_context_info(),
)
data: dict[str, Any] = {
self.enumerate_ids_field: objects_ids,
}
if chunk_index is not None and chunk_total and chunk_session:
data[self.enumerate_chunk_field] = {
"index": chunk_index,
"count": chunk_total,
"session": chunk_session,
}
extra = self.get_extra_data(None, batch)
if extra:
data.update(extra)
obj_id = objects_ids[0] if objects_ids else None
msg = Message(meta=meta, obj_id=obj_id, data=data)
return msg
def get_eos_message(self, timestamp: datetime) -> Message:
"""
Returns Message tuple for end of stream message type
"""
meta = MessageMeta(
timestamp=timestamp,
msg_type=TYPE_EOS,
context=self.get_context_info(),
)
msg = Message(meta=meta, obj_id=None, data={})
return msg
def get_context_info(self) -> MessageContext:
"""
Returns context information fields
"""
source = getattr(_context, "source", None) or get_setting("DEFAULT_SOURCE")
user = getattr(_context, "user", None)
if user is not None and user.is_authenticated():
user_id = user.pk
else:
user_id = None
context = MessageContext(
source=source,
user_id=user_id,
extra=None,
)
return context
def get_extra_data(
self, obj: Model | None, batch: Batch | None
) -> dict[str, Any] | None:
"""
Returns extra data fields for given object or batch. Default
implementation just returns `static_fields`.
"""
return self.static_fields
def get_batch(
self,
objects: Iterable[Model] | None = None,
queryset: QuerySet | None = None,
manager: Manager | None = None,
objects_ids: Sequence[ObjectID] | None = None,
**kwargs: Any,
) -> Batch:
return self.batch_class(
objects=objects,
queryset=queryset,
manager=manager,
objects_ids=objects_ids,
select_related=self.select_related,
prefetch_related=self.prefetch_related,
**kwargs,
)
def get_messages_for_batch(
self,
batch: Batch,
msg_type: str | None = None,
timestamp: datetime | None = None,
) -> Generator[Message, None]:
"""
Returns Message tuples for batch of objects
"""
try:
for obj in batch.get_objects():
yield self.get_message(
obj,
batch=batch,
msg_type=msg_type,
timestamp=timestamp,
)
except FieldError as e:
log.error("FieldError for model: %s: %s", batch.model, e)
raise
def get_messages_for_objects(
self,
objects: Iterable[Model],
manager: Manager | None = None,
objects_ids: Sequence[ObjectID] | None = None,
msg_type: str | None = None,
timestamp: datetime | None = None,
batch_size: int | None = None,
batch_kwargs: dict[str, Any] | None = None,
) -> Generator[Message, None]:
"""
Returns Message tuples for given objects with given message type
"""
if timestamp is None:
timestamp = timezone.now()
batch_size = batch_size or self.batch_size
queryset = None
if isinstance(objects, Manager):
manager = objects
queryset = objects.all()
elif isinstance(objects, QuerySet):
queryset = objects
if queryset is not None and batch_size:
if objects_ids is None:
ids = list(queryset.distinct().order_by().values_list("pk", flat=True))
else:
ids = list(objects_ids)
ids_chunked = [
ids[i : i + batch_size] for i in range(0, len(ids), batch_size)
]
for ids in ids_chunked:
batch = self.get_batch(
queryset=queryset,
manager=manager,
objects_ids=ids,
**(batch_kwargs or {}),
)
messages = self.get_messages_for_batch(
batch, msg_type=msg_type, timestamp=timestamp
)
for msg in messages:
yield msg
else:
batch = self.get_batch(objects=objects, manager=manager)
messages = self.get_messages_for_batch(
batch,
msg_type=msg_type,
timestamp=timestamp,
)
for msg in messages:
yield msg
def get_messages_for_ids_delete(
self,
objects_ids: Sequence[ObjectID],
timestamp: datetime | None = None,
manager: Manager | None = None,
) -> list[Message]:
"""
Returns Message tuples for delete messages for given objects IDs
"""
if timestamp is None:
timestamp = timezone.now()
batch = self.get_batch(objects_ids=objects_ids, manager=manager)
messages = [
self.get_delete_message(obj_id, timestamp, batch=batch)
for obj_id in objects_ids
]
return messages
def get_producer_options(self) -> dict[str, Any]:
return cast(dict[str, Any], get_setting("PRODUCER_OPTIONS"))
def get_producer(self, **kwargs: Any) -> KafkaProducer | None:
"""
Returns Kafka producer
"""
options = {
"value_serializer": self.message_serializer,
"key_serializer": self.partition_key_serializer,
"bootstrap_servers": get_setting("BOOTSTRAP_SERVERS"),
**(
{
"partitioner": self.partitioner,
}
if self.partitioner is not None
else {}
),
**self.get_producer_options(),
**kwargs,
}
if options.get("bootstrap_servers") is None:
raise ImproperlyConfigured(
"The `KAFKA_STREAMER['BOOTSTRAP_SERVERS']` is not configured."
)
if options["bootstrap_servers"] == []:
return None
try:
producer = KafkaProducer(**options)
except NoBrokersAvailable as e:
log.error("Kafka connect error: %s", e)
return None
return producer
def send_messages(
self,
messages: Iterable[Message],
batch_size: int | None = None,
producer: KafkaProducer | None = None,
flush: bool = True,
) -> int:
"""
Sends given messages to Kafka
"""
batch_size = batch_size or self.batch_size
if producer is None:
producer = self.get_producer()
if producer is None:
return 0
messages_send_count = 0
try:
for msg in messages:
if self.partition_key_serializer is not None:
key = msg
else:
key = None
producer.send(self.topic, msg, key=key)
messages_send_count += 1
if batch_size and messages_send_count % batch_size == 0:
producer.flush()
if flush:
producer.flush()
except KafkaTimeoutError as e:
log.error("Kafka connect error: %s", e)
return messages_send_count
def send_objects(
self,
objects: Iterable[Model],
manager: Manager | None = None,
objects_ids: Sequence[ObjectID] | None = None,
msg_type: str | None = None,
timestamp: datetime | None = None,
batch_size: int | None = None,
batch_kwargs: dict[str, Any] | None = None,
producer: KafkaProducer | None = None,
flush: bool = True,
) -> int:
"""
Sends given objects to Kafka
"""
messages = self.get_messages_for_objects(
objects,
manager=manager,
objects_ids=objects_ids,
msg_type=msg_type,
timestamp=timestamp,
batch_size=batch_size,
batch_kwargs=batch_kwargs,
)
return self.send_messages(
messages,
batch_size=batch_size,
producer=producer,
flush=flush,
)
def send_ids_delete(
self,
objects_ids: Sequence[ObjectID],
timestamp: datetime | None = None,
manager: Manager | None = None,
batch_size: int | None = None,
producer: KafkaProducer | None = None,
flush: bool = True,
) -> int:
"""
Sends delete messages for given objects IDs
"""
messages = self.get_messages_for_ids_delete(
objects_ids,
timestamp=timestamp,
manager=manager,
)
return self.send_messages(
messages,
batch_size=batch_size,
producer=producer,
flush=flush,
)
def send_ids_enumerate(
self,
objects_ids: Sequence[ObjectID],
timestamp: datetime | None = None,
manager: Manager | None = None,
producer: KafkaProducer | None = None,
flush: bool = True,
chunk_size: int | None = None,
) -> int:
"""
Sends enumerate message for given objects IDs
"""
if timestamp is None:
timestamp = timezone.now()
if chunk_size is None:
chunk_size = self.enumerate_chunk_size
batch = self.get_batch(manager=manager)
if len(objects_ids) <= chunk_size:
messages = [
self.get_enumerate_message(
objects_ids,
timestamp,
batch=batch,
),
]
else:
ids_chunked = [
objects_ids[i : i + chunk_size]
for i in range(0, len(objects_ids), chunk_size)
]
chunk_session = str(uuid.uuid4())
messages = [
self.get_enumerate_message(
ids,
timestamp,
batch=batch,
chunk_index=idx,
chunk_total=len(ids_chunked),
chunk_session=chunk_session,
)
for idx, ids in enumerate(ids_chunked)
]
return self.send_messages(messages, producer=producer, flush=flush)
def send_eos(
self,
timestamp: datetime | None = None,
producer: KafkaProducer | None = None,
flush: bool = True,
) -> int:
"""
Sends end of stream messages
"""
if timestamp is None:
timestamp = timezone.now()
msg = self.get_eos_message(timestamp=timestamp)
return self.send_messages([msg], producer=producer, flush=flush)