[Contribute] Migrate LeRobot VLA Models to XLeRobot

Open Vector-Wangel opened this issue 4 months ago • 9 comments

📋 Issue Description

Migrate existing VLA models from LeRobot (originally for single SO-101 arms) to work with XLeRobot dual-arm mobile platform.

🎯 What We Need

Target Model Options: DP3, ACT, Pi0, SmolVLA, GR00T

Requirements:

  • Model inference on XLeRobot hardware/simulation
  • Action space mapping from single arm to dual-arm + base
  • Working demo with pick-and-place task
  • Basic documentation

🔧 Expected Approach

  1. Study LeRobot VLA implementations (refer to the [LeRobot repository](https://github.com/huggingface/lerobot))
  2. Create an action space adapter for the XLeRobot format (see the sketch after this list)
  3. Test with pre-trained checkpoints
  4. Create demo script showing task execution
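For step 2, a rough and purely hypothetical sketch of what such an adapter could look like; the joint counts, dictionary keys, and zero-filled placeholders are assumptions, not XLeRobot's actual action format:

import numpy as np


def adapt_single_arm_action(single_arm_action: np.ndarray) -> dict[str, np.ndarray]:
    """Map a single SO-101 arm policy output onto a dual-arm + mobile-base command.

    Assumes a 6-DoF arm plus gripper (7 values); the second arm and the base are
    held still. All key names here are illustrative placeholders.
    """
    assert single_arm_action.shape == (7,)
    return {
        "left_arm": single_arm_action,   # reuse the policy's output on one arm
        "right_arm": np.zeros(7),        # keep the other arm at a neutral command
        "base": np.zeros(3),             # (vx, vy, omega) -> stay in place
    }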

📝 Acceptance Criteria

  • [ ] Successfully run at least 1 VLA model on XLeRobot
  • [ ] Demo video
  • [ ] Basic setup documentation
  • [ ] Example usage script

💡 Resources

  • [LeRobot Models](https://github.com/huggingface/lerobot)
  • XLeRobot Examples
  • Dual-arm Control Reference

🤝 Propose Your Approach

Comment with your implementation plan! Include:

  • Which models you'll start with
  • How you'll handle action space mapping
  • Timeline estimate

This will bring powerful VLA capabilities to our dual-arm mobile platform! 🤖

Vector-Wangel avatar Aug 21 '25 00:08 Vector-Wangel

After speaking with Dor about this, here are a few pointers, which I am putting down below:

  • There should be a Sort class that provides the direction of the sort; sorting will be driven from the client side (via a query param like sort=asc or sort=desc).
  • For the json_file handler, the $gt filter on the cursor will not work, so pagination for json_file needs to be done in the adapter itself while keeping find as it is (I am still figuring out this part from a design perspective).
  • Write DB and API test cases first (which I have started), as it will help me with point 2.

Agam1997 avatar Sep 03 '25 05:09 Agam1997

New update:

In the API tests, test_that_sessions_can_be_listed remains as-is and is responsible for testing the normal legacy non-paginated response.

Test additions:

# File: /parlant/tests/api/test_sessions.py
async def test_that_sessions_are_paginated_when_listed(
    async_client: httpx.AsyncClient,
    container: Container
) -> None:
    agents = [
        await create_agent(container, "first-agent"),
    ]

    # 10 sessions
    sessions = []
    for i in range(10):
        session = await create_session(
            container, 
            agent_id=agents[0].id, 
            title=f"session-{i}"
        )
        sessions.append(session)

    response = await async_client.get("/sessions", params={"limit": 5})
    data = response.raise_for_status().json()
    
    assert "sessions" in data
    assert "next_cursor" in data
    assert "total_count" in data
    assert "has_more" in data
    assert len(data["sessions"]) == 5
    assert data["total_count"] == 10
    assert data["has_more"] is True

async def test_sessions_pagination_with_cursor(
    async_client: httpx.AsyncClient,
    container: Container,
) -> None:
    """Test cursor-based pagination navigation."""
    agent = await create_agent(container, "test-agent")
    
    # 7 sessions
    for i in range(7):
        await create_session(container, agent_id=agent.id, title=f"session-{i}")

    response1 = await async_client.get("/sessions", params={"limit": 3})
    data1 = response1.raise_for_status().json()
    
    assert len(data1["sessions"]) == 3
    assert data1["has_more"] is True
    assert data1["next_cursor"] is not None
    
    response2 = await async_client.get("/sessions", params={
        "cursor": data1["next_cursor"],
        "limit": 3
    })
    data2 = response2.raise_for_status().json()
    
    assert len(data2["sessions"]) == 3
    assert data2["has_more"] is True

    response3 = await async_client.get("/sessions", params={
        "cursor": data2["next_cursor"],
        "limit": 3
    })
    data3 = response3.raise_for_status().json()
    
    assert len(data3["sessions"]) == 1
    assert data3["has_more"] is False
    assert data3["next_cursor"] is None
    
    # Check if any overlap
    page1_ids = {s["id"] for s in data1["sessions"]}
    page2_ids = {s["id"] for s in data2["sessions"]}
    page3_ids = {s["id"] for s in data3["sessions"]}
    
    assert page1_ids.isdisjoint(page2_ids)
    assert page1_ids.isdisjoint(page3_ids)
    assert page2_ids.isdisjoint(page3_ids)

async def test_sessions_pagination_sort_directions(
    async_client: httpx.AsyncClient,
    container: Container,
) -> None:
    """Test ascending vs descending sort."""
    agent = await create_agent(container, "test-agent")
    
    sessions = []
    for i in range(7):
        session = await create_session(container, agent_id=agent.id, title=f"session-{i}")
        sessions.append(session)
        await asyncio.sleep(0.015)  # Small delay so entries have different creation_utc
    
    response_desc = await async_client.get("/sessions", params={"limit": 7, "sort": "desc"})
    data_desc = response_desc.raise_for_status().json()
    
    response_asc = await async_client.get("/sessions", params={"limit": 7, "sort": "asc"})
    data_asc = response_asc.raise_for_status().json()
    
    assert len(data_desc["sessions"]) == len(data_asc["sessions"])
    assert data_desc["sessions"][0]["id"] == data_asc["sessions"][-1]["id"]
    assert data_desc["sessions"][-1]["id"] == data_asc["sessions"][0]["id"]

async def test_sessions_pagination_with_filters(
    async_client: httpx.AsyncClient,
    container: Container,
) -> None:
    """Test pagination combined with existing filters."""
    agents = [
        await create_agent(container, "agent-1"),
        await create_agent(container, "agent-2"),
    ]
    
    # Create sessions with different agents
    for i in range(3):
        await create_session(container, agent_id=agents[0].id, title=f"agent1-session-{i}")
    for i in range(2):
        await create_session(container, agent_id=agents[1].id, title=f"agent2-session-{i}")
    
    response = await async_client.get("/sessions", params={
        "agent_id": agents[0].id,
        "limit": 2
    })
    data = response.raise_for_status().json()
    
    assert len(data["sessions"]) == 2
    assert data["total_count"] == 3
    assert data["has_more"] is True
    assert all(s["agent_id"] == agents[0].id for s in data["sessions"])

async def test_sessions_pagination_empty_results(
    async_client: httpx.AsyncClient,
    container: Container,
) -> None:
    """Test pagination with no sessions."""
    response = await async_client.get("/sessions", params={"limit": 10})
    data = response.raise_for_status().json()
    
    assert data["sessions"] == []
    assert data["total_count"] == 0
    assert data["has_more"] is False
    assert data["next_cursor"] is None

async def test_sessions_pagination_invalid_cursor(
    async_client: httpx.AsyncClient,
    container: Container,
) -> None:
    """Test pagination with invalid cursor."""
    agent = await create_agent(container, "test-agent")
    await create_session(container, agent_id=agent.id)
    
    response = await async_client.get("/sessions", params={
        "cursor": "invalid-cursor",
        "limit": 10
    })
    data = response.raise_for_status().json()
    
    assert len(data["sessions"]) == 1
    assert data["total_count"] == 1

Agam1997 avatar Sep 05 '25 11:09 Agam1997

Tests for adapters:

async def test_sessions_retrieval(context: _TestContext, new_file: Path) -> None:
    async with JSONFileDocumentDatabase(context.container[Logger], new_file) as session_db:
        async with SessionDocumentStore(session_db) as session_store:
            sessions = []
            for i in range(10):
                customer_id = CustomerId(f"test_customer_{i}")
                title = f"test_title_{i}"
                utc_now = datetime.now(timezone.utc)
                session = await session_store.create_session(
                    creation_utc=utc_now,
                    customer_id=customer_id,
                    agent_id=context.agent_id,
                    title=title
                )
                sessions.append(session)
            loaded_sessions = await session_store.list_sessions(agent_id=context.agent_id)
        loaded_sessions = list(loaded_sessions)

    assert len(loaded_sessions) == len(sessions)
    loaded_session = loaded_sessions[0]
    assert loaded_session.title == sessions[0].title
    assert loaded_session.customer_id == sessions[0].customer_id
    assert loaded_session.agent_id == context.agent_id

Working on the mongo one

Agam1997 avatar Sep 09 '25 14:09 Agam1997

@mc-dorzo should I start working on the implementation? Are these tests enough for now? (Do check the adapters one, as I'm not sure if it's right.)

Agam1997 avatar Sep 09 '25 14:09 Agam1997

Hey team,

First of all, apologies that there isn't much progress on this issue yet. We are still weighing a few things in the design phase and would appreciate inputs.

We have decided to go with cursor-based pagination instead of skip and limit. We have also decided to add sorting via a sort query param:

CursorQuery: TypeAlias = Annotated[
    Optional[SessionId],
    Query(
        description="Cursor for pagination - session ID to start after",
        examples=["sess_123yz"],
    ),
]

LimitQuery: TypeAlias = Annotated[
    Optional[int],
    Query(
        description="Maximum number of sessions to return",
        examples=[10, 50],
        ge=1,
        le=100,
    ),
]

SortQuery: TypeAlias = Annotated[
    Optional[str],
    Query(
        description="Sort direction for the returned sessions",
        examples=["asc", "desc"],
    ),
]

Next, we have decided to add a Sort class like this:

from typing import NamedTuple, Literal

class Sort(NamedTuple):
    field: str
    direction: Literal["asc", "desc"] = "asc"
    
    @classmethod
    def by_creation_utc(cls, direction: Literal["asc", "desc"] = "asc") -> 'Sort':
        return cls("creation_utc", direction)
    
    @property
    def is_ascending(self) -> bool:
        return self.direction == "asc"
    
    @property
    def is_descending(self) -> bool:
        return self.direction == "desc"
    
    @property
    def mongo_direction(self) -> int:
        return 1 if self.is_ascending else -1
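A quick usage sketch of the class (the values here are purely illustrative):

# Newest-first listing, the intended default for the paginated sessions endpoint
sort = Sort.by_creation_utc("desc")

assert sort.field == "creation_utc"
assert sort.is_descending
assert sort.mongo_direction == -1  # feeds straight into MongoDB's sort spec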

After thinking for a while about which field is best suited to serve as the cursor, I had decided to go with session.id. However, that comes with a challenge: in the json_file document store, the session ID is generated randomly, so using session.id as the cursor does not give an ordering to paginate over. For other databases the ID field is generally a primary, incremental index, so a $gt filter would work well for them. To also handle json_file, I have decided to encode the cursor as a combination of session.id and session.creation_utc. This generated string is then decoded, and we can use the appropriate field as the cursor; for now I have decided to use creation_utc, which has its own challenges, discussed later.

This would go in common.py inside the persistence folder:

import base64
import json
from typing import Dict


def encode_cursor(session_id: str, creation_utc: str) -> str:
    """Encode the session ID and creation timestamp into an opaque cursor."""
    cursor_data = json.dumps({
        "id": session_id,
        "creation_utc": creation_utc
    }, sort_keys=True)
    return base64.urlsafe_b64encode(cursor_data.encode()).decode().rstrip('=')


def decode_cursor(cursor: str) -> Dict[str, str]:
    """Decode a cursor back into the session ID and creation timestamp."""
    cursor += '=' * (-len(cursor) % 4)  # restore the padding stripped during encoding
    decoded = base64.urlsafe_b64decode(cursor).decode()
    return json.loads(decoded)
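A quick round-trip check of the helpers (the ID and timestamp below are made up):

cursor = encode_cursor("sess_123yz", "2025-09-10T12:00:00+00:00")
assert decode_cursor(cursor) == {
    "id": "sess_123yz",
    "creation_utc": "2025-09-10T12:00:00+00:00",
}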

In addition, common.py will also have these classes:

from dataclasses import dataclass
from typing import Generic, Optional, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class PaginationParams:
    cursor: Optional[str] = None
    limit: int = 50
    sort: Optional[Sort] = None


@dataclass(frozen=True)
class PaginationResult(Generic[T]):
    items: Sequence[T]
    next_cursor: Optional[str]
    total_count: int
    has_more: bool

I want to provide a backwards-compatible path: people already using the existing Sessions API will continue to use the older behavior, and we will provide a migration guide for moving to the new paginated API. For that, I will add these two abstract methods in document_database.py:

from parlant.core.persistence.common import PaginationParams, PaginationResult

class DocumentCollection(ABC, Generic[TDocument]):
    # ... existing methods unchanged ...

    @abstractmethod
    async def find_paginated(
        self,
        filters: Where,
        pagination: PaginationParams,
    ) -> PaginationResult[TDocument]:
        """Finds documents with cursor-based pagination."""
        ...

    @abstractmethod
    async def count(
        self,
        filters: Where,
    ) -> int:
        """Counts documents matching filters."""
        ...

Then in the SessionStore and SessionDocumentStore we make changes like this:

from parlant.core.persistence.common import PaginationParams, PaginationResult

class SessionStore(ABC):
    # ... existing methods unchanged ...

    @abstractmethod
    async def list_sessions_paginated(
        self,
        agent_id: Optional[AgentId] = None,
        customer_id: Optional[CustomerId] = None,
        pagination: PaginationParams = PaginationParams(),
    ) -> PaginationResult[Session]: ...

class SessionDocumentStore(SessionStore):
    # ... existing methods unchanged ...

    @override
    async def list_sessions_paginated(
        self,
        agent_id: Optional[AgentId] = None,
        customer_id: Optional[CustomerId] = None,
        pagination: PaginationParams = PaginationParams(),
    ) -> PaginationResult[Session]:
        async with self._lock.reader_lock:
            filters = {
                **({"agent_id": {"$eq": agent_id}} if agent_id else {}),
                **({"customer_id": {"$eq": customer_id}} if customer_id else {}),
            }
            
            result = await self._session_collection.find_paginated(
                filters=cast(Where, filters),
                pagination=pagination,
            )
            
            return PaginationResult(
                items=[self._deserialize_session(doc) for doc in result.items],
                next_cursor=result.next_cursor,
                total_count=result.total_count,
                has_more=result.has_more,
            )
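For context, here is roughly how I picture the API layer tying the query aliases from above to this store method. The route body is only a sketch: session_store, serialize_session, and legacy_list_sessions_response are hypothetical placeholders for whatever the real route uses, and the response keys mirror the fields asserted in the API tests (sessions, next_cursor, total_count, has_more):

@router.get("/sessions")
async def list_sessions(
    agent_id: Optional[AgentId] = None,
    cursor: CursorQuery = None,
    limit: LimitQuery = None,
    sort: SortQuery = None,
) -> Any:
    if cursor is None and limit is None and sort is None:
        # Legacy path: delegate to whatever the endpoint returns today (placeholder)
        return await legacy_list_sessions_response(agent_id=agent_id)

    result = await session_store.list_sessions_paginated(
        agent_id=agent_id,
        pagination=PaginationParams(
            cursor=cursor,
            limit=limit or 50,
            sort=Sort.by_creation_utc(sort or "asc"),
        ),
    )
    return {
        "sessions": [serialize_session(s) for s in result.items],
        "next_cursor": result.next_cursor,
        "total_count": result.total_count,
        "has_more": result.has_more,
    }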

And then in the adapters:

mongo_db.py

from parlant.core.persistence.common import (
    PaginationParams, 
    PaginationResult, 
    decode_cursor, 
    encode_cursor
)

class MongoDocumentCollection(DocumentCollection[TDocument]):
    # ... existing methods unchanged ...

    async def find_paginated(
        self,
        filters: Where,
        pagination: PaginationParams,
    ) -> PaginationResult[TDocument]:
        pipeline = []

        if filters:
            pipeline.append({"$match": filters})

        # Resume strictly after the cursor position, in the direction of the sort
        if pagination.cursor and pagination.sort:
            cursor_values = decode_cursor(pagination.cursor)
            field_name = pagination.sort.field
            if field_name in cursor_values:
                op = "$gt" if pagination.sort.is_ascending else "$lt"
                pipeline.append({"$match": {field_name: {op: cursor_values[field_name]}}})

        if pagination.sort:
            pipeline.append({"$sort": {pagination.sort.field: pagination.sort.mongo_direction}})

        # Fetch one extra item to detect whether another page exists
        pipeline.append({
            "$facet": {
                "items": [{"$limit": pagination.limit + 1}],
                "total_count": [{"$count": "count"}],
            }
        })

        result = await self._collection.aggregate(pipeline).to_list(length=None)

        items = result[0]["items"]
        total_count = result[0]["total_count"][0]["count"] if result[0]["total_count"] else 0

        has_more = len(items) > pagination.limit
        if has_more:
            items = items[:pagination.limit]

        next_cursor = None
        if has_more and items:
            last_item = items[-1]
            next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])

        return PaginationResult(
            items=items,
            next_cursor=next_cursor,
            total_count=total_count,
            has_more=has_more,
        )

    async def count(self, filters: Where) -> int:
        return await self._collection.count_documents(filters)

json_file.py

from parlant.core.persistence.common import (
    PaginationParams, 
    PaginationResult, 
    decode_cursor, 
    encode_cursor
)

class JSONFileDocumentCollection(DocumentCollection[TDocument]):
    # ... existing methods unchanged ...

    @override
    async def find_paginated(
        self,
        filters: Where,
        pagination: PaginationParams,
    ) -> PaginationResult[TDocument]:
        async with self._lock.reader_lock:
            filtered_docs = [doc for doc in self.documents if matches_filters(filters, doc)]
            total_count = len(filtered_docs)

            if pagination.sort:
                sort = pagination.sort
                filtered_docs.sort(
                    key=lambda doc: doc.get(sort.field, ""),
                    reverse=sort.is_descending,
                )

            start_idx = 0
            if pagination.cursor and pagination.sort:
                cursor_values = decode_cursor(pagination.cursor)
                # Default to the end of the list in case nothing lies past the cursor
                start_idx = len(filtered_docs)
                for i, doc in enumerate(filtered_docs):
                    if self._is_after_cursor(doc, cursor_values, pagination.sort):
                        start_idx = i
                        break

            end_idx = start_idx + pagination.limit
            items = filtered_docs[start_idx:end_idx]
            has_more = end_idx < len(filtered_docs)

            next_cursor = None
            if has_more and items:
                last_item = items[-1]
                next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])

            return PaginationResult(
                items=items,
                next_cursor=next_cursor,
                total_count=total_count,
                has_more=has_more,
            )

    def _is_after_cursor(self, doc: TDocument, cursor_values: dict, sort: Sort) -> bool:
        doc_value = doc.get(sort.field)
        cursor_value = cursor_values.get(sort.field)

        if doc_value == cursor_value:
            return False

        if sort.is_ascending:
            return doc_value > cursor_value
        return doc_value < cursor_value

    @override
    async def count(self, filters: Where) -> int:
        async with self._lock.reader_lock:
            return len([doc for doc in self.documents if matches_filters(filters, doc)])

Challenge

Using creation_utc has its own pitfalls.

At scale, with over 10 million records, there can be records with identical or overlapping creation_utc timestamps. As more and more new sessions get created, collisions become likely; my estimate comes out to around 19 collisions per 10 million records per month.

Another challenge, or rather a performance overhead, of using creation_utc is that it is a string type: $gt performs a lexicographical comparison, which goes character by character.

Thank you @mc-dorzo for your continued help, I really appreciate it.

@kichanyurd what do you think?

Agam1997 avatar Sep 10 '25 15:09 Agam1997

I did some research on the challenges of using creation_utc as the cursor, and here are my findings:

Most production deployments of Parlant will use (or should use) MongoDB for persistence. Having an index on the creation_utc field dramatically improves performance.

Case: 10 million records in the sessions collection, with no index:

# MongoDB has to scan EVERY document
db.sessions.find({"creation_utc": {"$lt": "2024-03-24T10:00:00Z"}})
# Scans: 10,000,000 documents
# Time: ~30+ seconds

With index:

db.sessions.createIndex({"creation_utc": -1})  # Descending index

# Same query now uses index
db.sessions.find({"creation_utc": {"$lt": "2024-03-24T10:00:00Z"}})
# Index lookup: O(log n) = ~23 operations for 10M records
# Time: ~5-50ms

Lexicographic comparison IS fast because:

  • Index is pre-sorted - MongoDB doesn't compare strings during query
  • Binary search - O(log n) to find position
  • Sequential read - Once position found, just read next N records

An important thing to note is that an index comes with its own storage overhead. For 10M records and an index on creation_utc, the overhead would be roughly:

Each entry: ~32 bytes (timestamp string + document pointer). Total index: 10M × 32 bytes ≈ 320MB,

which is an acceptable overhead for the performance gain.
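On the application side the index could be created once at startup; a minimal sketch assuming Motor is the driver (the collection handle and function name are illustrative):

import pymongo
from motor.motor_asyncio import AsyncIOMotorDatabase


async def ensure_session_indexes(db: AsyncIOMotorDatabase) -> None:
    # Descending index on creation_utc to back newest-first keyset pagination
    await db["sessions"].create_index([("creation_utc", pymongo.DESCENDING)])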

Coming to the json_file document DB, which I feel is a completely different issue to solve, as:

class JSONFileDocumentCollection(DocumentCollection[TDocument]):
    def __init__(
        self,
        database: JSONFileDocumentDatabase,
        name: str,
        schema: type[TDocument],
        data: Sequence[TDocument] | None = None,
    ) -> None:
        self._database = database
        self._name = name
        self._schema = schema
        self._op_counter = 0

        self._lock = ReaderWriterLock()

        self.documents = list(data) if data else []

As the data gets loaded entirely into memory, this becomes increasingly problematic with large datasets; the JSON file DB is fundamentally not designed for them. To address this, I would suggest moving to SQLite, as the advantages are huge (see the sketch below).
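For illustration, keyset pagination maps naturally onto SQLite; a minimal sketch with the standard sqlite3 module (the table and column names are assumptions, not the current schema):

import sqlite3
from typing import Optional


def list_sessions_page(
    conn: sqlite3.Connection,
    after_creation_utc: Optional[str],
    limit: int,
) -> list[tuple]:
    # Keyset pagination: resume strictly after the last creation_utc the client saw
    if after_creation_utc is None:
        query = "SELECT id, creation_utc, title FROM sessions ORDER BY creation_utc LIMIT ?"
        params: tuple = (limit,)
    else:
        query = (
            "SELECT id, creation_utc, title FROM sessions "
            "WHERE creation_utc > ? ORDER BY creation_utc LIMIT ?"
        )
        params = (after_creation_utc, limit)
    return conn.execute(query, params).fetchall()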

let me know what you guys think.

Agam1997 avatar Sep 11 '25 04:09 Agam1997

First of all, @Agam1997, great job on the deep research—you’re one of a kind!

Two notes:

  1. Could you please update the interface here for the case where we actually want to extend the find() function in DocumentCollection? We need it to remain backward compatible, but that's possible as long as any newly introduced parameters are optional.
  2. Question: I noticed you wrote creation_utc with a format that excludes microseconds. Do you think that could make a difference?
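To make point 2 concrete (a quick illustration; whether the stored timestamps actually include microseconds is not confirmed here):

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
print(now.isoformat())                    # e.g. 2025-09-11T11:02:03.482913+00:00 (microseconds kept)
print(now.isoformat(timespec="seconds"))  # e.g. 2025-09-11T11:02:03+00:00 (microseconds dropped)

# Once microseconds are dropped, sessions created within the same second share a
# creation_utc, so a strict $gt/$lt cursor on that field alone can skip or repeat
# rows at page boundaries.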

mc-dorzo avatar Sep 11 '25 11:09 mc-dorzo

class DocumentCollection(ABC, Generic[TDocument]):
    @abstractmethod
    async def find(
        self,
        filters: Where,
        sort: Optional[Sort] = None,
        limit: Optional[int] = None,
        cursor: Optional[str] = None,
    ) -> Union[Sequence[TDocument], PaginationResult[TDocument]]:
        """Finds documents. Returns PaginationResult if pagination params provided."""
        ...

I suppose it would look like this:

class JSONFileDocumentCollection(DocumentCollection[TDocument]):
    @override
    async def find(
        self,
        filters: Where,
        sort: Optional[Sort] = None,
        limit: Optional[int] = None,
        cursor: Optional[str] = None,
    ) -> Union[Sequence[TDocument], PaginationResult[TDocument]]:
        is_paginated = sort is not None or limit is not None or cursor is not None

        async with self._lock.reader_lock:
            filtered_docs = [doc for doc in self.documents if matches_filters(filters, doc)]

            if not is_paginated:
                # Legacy behavior: plain sequence, exactly as before
                return filtered_docs

            total_count = len(filtered_docs)

            if sort:
                filtered_docs.sort(
                    key=lambda doc: doc.get(sort.field, ""),
                    reverse=sort.is_descending,
                )

            if cursor and sort:
                cursor_data = decode_cursor(cursor)
                cursor_value = cursor_data.get(sort.field)

                # Find the first document strictly past the cursor position;
                # if none exists, we are already past the last page
                start_idx = len(filtered_docs)
                for i, doc in enumerate(filtered_docs):
                    doc_value = doc.get(sort.field)
                    if doc_value == cursor_value:
                        continue
                    if (sort.is_ascending and doc_value > cursor_value) or (
                        sort.is_descending and doc_value < cursor_value
                    ):
                        start_idx = i
                        break

                filtered_docs = filtered_docs[start_idx:]

            items = filtered_docs[:limit] if limit else filtered_docs
            has_more = limit is not None and len(filtered_docs) > limit

            next_cursor = None
            if has_more and items and sort:
                last_item = items[-1]
                next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])

            return PaginationResult(
                items=items,
                next_cursor=next_cursor,
                total_count=total_count,
                has_more=has_more,
            )

Is the return type Union[Sequence[TDocument], PaginationResult[TDocument]] correct? It just feels very non-intuitive to me, so I might have made a number of mistakes here.

Agam1997 avatar Sep 12 '25 04:09 Agam1997