django-rest-framework-csv
django-rest-framework-csv copied to clipboard
Suggestion: a CSV export mixin
Hello,
I created a mixin that uses this renderer and wanted to share it. The mixin can render fixed fields (set on the view), all fields (taken from the model), or dynamic fields (received in the request).
To filter fields from the request:
set csv_fixed_fields to False
send the GET parameter: fields={id, fullname}
To exclude one or more fields from the CSV, prefix the field name with a minus sign: fields={-id, -fullname}
To always export all model fields:
set csv_fixed_fields to True
To export a fixed, manually chosen set of fields:
set csv_fixed_fields to a list, e.g. ["id", ...]
Not implemented:
- filtering and labels for nested data.
Does it make sense to open a PR with this file?
Implementation:
from rest_framework.settings import api_settings
from rest_framework_csv.renderers import CSVRenderer
class CsvRenderMixin:
    """Add CSV output (via ``CSVRenderer``) to a DRF generic view.

    The CSV columns are chosen by ``csv_fixed_fields``:

    * ``False`` (default): columns come from the ``csv_filter_param`` query
      parameter, e.g. ``?fields={id, fullname}`` keeps only those model
      fields and ``?fields={-id}`` drops fields. Without the parameter, every
      field in ``queryset.model._meta.fields`` is exported.
    * ``True``: every field in ``queryset.model._meta.fields``.
    * a list/tuple of names: exactly those columns, in that order.

    Column labels are each model field's ``verbose_name``. Nested data is
    neither filtered nor labelled.

    NOTE: list this mixin before the DRF view class in the bases so that its
    ``super()`` calls reach the view's implementations.
    """

    # Keep the project's default renderers and add CSV on top of them.
    renderer_classes = (*api_settings.DEFAULT_RENDERER_CLASSES, CSVRenderer)
    # False -> request-driven columns; True -> all model fields;
    # list/tuple -> exactly these columns.
    csv_fixed_fields = False
    # Query parameter read by get_allowed_fields() when csv_fixed_fields is False.
    csv_filter_param = 'fields'

    def _is_csv_request(self):
        """Return True when content negotiation selected the CSV renderer."""
        renderer = getattr(self.request, 'accepted_renderer', None)
        return getattr(renderer, 'format', None) == 'csv'

    def _get_model(self):
        """Return ``self.queryset.model``, or None when the view has no queryset."""
        return getattr(self.queryset, 'model', None)

    def paginate_queryset(self, queryset):
        """Paginate normally, but disable pagination for CSV exports.

        Returning None makes DRF's list view serialize the whole queryset as a
        plain list. A CSV export therefore contains every row, instead of one
        page capped at an arbitrary size and wrapped in pagination metadata.
        This works for any paginator class, including ones that ignore
        ``page_size``.
        """
        if self._is_csv_request():
            return None
        return super().paginate_queryset(queryset)

    def get_renderer_context(self):
        """Add the CSV ``header`` and ``labels`` entries for CSV responses."""
        context = super().get_renderer_context()
        if self._is_csv_request():
            context['header'] = self.get_allowed_fields()
            context['labels'] = self.mount_labels()
        return context

    def mount_labels(self):
        """Map each model field name to its ``verbose_name`` (renderer ``labels``)."""
        model = self._get_model()
        if model is None:
            return {}
        return {field.name: field.verbose_name for field in model._meta.fields}

    def get_existing_fields(self):
        """Return the model's field names in declaration order ([] without a model)."""
        model = self._get_model()
        if model is None:
            return []
        return [field.name for field in model._meta.fields]

    def _parse_requested_fields(self):
        """Split the filter query parameter into ``(include, exclude)`` name sets.

        The accepted format is ``{a, b, -c}``. Braces are optional, whitespace
        around names is ignored, and a leading ``-`` marks a field to exclude.
        Empty tokens, including a bare ``-``, are skipped.
        """
        raw = self.request.query_params.get(self.csv_filter_param, '')
        tokens = raw.replace('{', '').replace('}', '').split(',')
        include, exclude = set(), set()
        for token in tokens:
            token = token.strip()
            if token.startswith('-'):
                name = token[1:].strip()
                if name:
                    exclude.add(name)
            elif token:
                include.add(token)
        return include, exclude

    def get_allowed_fields(self):
        """Return the list of column names to export.

        For request-driven filtering, the result keeps the model's field order.
        Requested names that are not model fields are ignored rather than
        raising.
        """
        fields = self.get_existing_fields()
        if self.csv_fixed_fields is True:
            return fields
        if self.csv_fixed_fields:
            # Copy so callers cannot mutate the class-level configuration.
            return list(self.csv_fixed_fields)

        include, exclude = self._parse_requested_fields()
        if include:
            # Keep only the requested fields that actually exist on the model.
            fields = [name for name in fields if name in include]
        if exclude:
            fields = [name for name in fields if name not in exclude]
        return fields