update tests, model manager, simplify the import command

This commit is contained in:
Tyrel Souza 2015-09-08 12:09:03 -05:00
parent 7a0cc6310e
commit e47f92d6c8
3 changed files with 39 additions and 33 deletions

View File

@ -1,8 +1,7 @@
__author__ = 'tyrel' __author__ = 'tyrel'
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth.models import User from django.contrib.auth.models import User
from tracking.models import Track, Segment, Point from tracking.models import Track
from django.contrib.gis import geos
import xmltodict import xmltodict
import os import os
@ -31,36 +30,7 @@ class Command(BaseCommand):
def import_file(self, filename): def import_file(self, filename):
with open(filename) as fd: with open(filename) as fd:
obj = dict(xmltodict.parse(fd.read()))['gpx'] gpx = dict(xmltodict.parse(fd.read()))['gpx']
user = self.user user = self.user
create_points_segments_tracks(obj, user) Track.objects.create_from_gpx(gpx, user)
def create_points_segments_tracks(obj, user):
# Get some Meta data from the gpx file.
metadata = obj['metadata']
start_time = metadata['time']
# Get the the track and segments from the gpx file.
tracks = obj['trk']
name = tracks['name']
segments = tracks['trkseg']
# Check unimplemented features
if isinstance(segments, list):
raise NotImplemented(
"Haven't run across one that has multiple segments. "
"Contribute a patch if you have.")
elif isinstance(segments, dict):
track = Track.objects.create(user=user,
start=start_time,
name=name)
points = segments['trkpt']
segment = Segment(track=track, time=points[0]['time'])
segment.save()
for pt in points:
lat, lon = float(pt['@lat']), float(pt['@lon'])
poly_pt = geos.Point(lon, lat)
Point.objects.create(segment=segment,
time=pt['time'],
elevation=pt['ele'],
point=poly_pt)

View File

@ -1,14 +1,46 @@
from django.contrib.gis.db import models from django.contrib.gis.db import models
from django.contrib.gis import geos
from django.db import transaction
# Create your models here. # Create your models here.
from django.contrib.auth.models import User from django.contrib.auth.models import User
class TrackManager(models.Manager):
def create_from_gpx(cls, obj, user):
# Get some Meta data from the gpx file.
metadata = obj['metadata']
start_time = metadata['time']
# Get the the track and segments from the gpx file.
tracks = obj['trk']
name = tracks['name']
segments = tracks['trkseg']
# Check unimplemented features
with transaction.atomic():
track = Track.objects.create(user=user,
start=start_time,
name=name)
points = segments['trkpt']
segment = Segment(track=track, time=points[0]['time'])
segment.save()
for pt in points:
lat, lon = float(pt['@lat']), float(pt['@lon'])
poly_pt = geos.Point(lon, lat)
Point.objects.create(segment=segment,
time=pt['time'],
elevation=pt['ele'],
point=poly_pt)
return track
raise Exception("Track not created, rolling back DB")
class Track(models.Model): class Track(models.Model):
user = models.ForeignKey(User, blank=False) user = models.ForeignKey(User, blank=False)
start = models.DateTimeField() start = models.DateTimeField()
name = models.CharField(max_length=1024, blank=False) name = models.CharField(max_length=1024, blank=False)
description = models.TextField(blank=True) description = models.TextField(blank=True)
objects = TrackManager()
@property @property
def finish(self): def finish(self):

View File

@ -1,5 +1,6 @@
from django.test import TestCase from django.test import TestCase
from django.core import management from django.core import management
from tracking.models import Track
from django.core.management.base import CommandError from django.core.management.base import CommandError
# Create your tests here. # Create your tests here.
@ -19,6 +20,9 @@ class ImportGPXManagementTest(TestCase):
management.call_command('import_gpx', 'tyrel', '/noop/nodir/') management.call_command('import_gpx', 'tyrel', '/noop/nodir/')
def test_import(self): def test_import(self):
track_count = Track.objects.all().count()
management.call_command('import_gpx', 'tyrel', 'testing/example-ride.gpx') management.call_command('import_gpx', 'tyrel', 'testing/example-ride.gpx')
new_track_count = Track.objects.all().count()
self.assertNotEqual(track_count, new_track_count)